Compare commits

...

1 Commit

Author     SHA1        Message                Date
sayakpaul  2e383fc901  batched cfg for flux.  2024-09-20 14:57:52 +05:30

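Reviewer note before the diff itself: this commit replaces the previous two-pass true CFG (one transformer call for the positive prompt, a second one for the negative) with a single batched call whose output is chunked apart before the guidance combination. A minimal, self-contained sketch of that pattern, with a toy model standing in for the Flux transformer (the model, shapes, and names here are illustrative, not the real pipeline):

import torch

torch.manual_seed(0)

def model(x, emb):
    # Toy stand-in for the Flux transformer: the text conditioning shifts
    # the prediction, so the positive and negative branches actually differ.
    return x * 0.5 + emb.mean(dim=(1, 2), keepdim=True)

true_cfg = 4.0
latents = torch.randn(2, 16, 64)   # (batch, tokens, channels), toy sizes
pos_emb = torch.randn(2, 8, 64)
neg_emb = torch.randn(2, 8, 64)

# Unbatched true CFG (the old code path): two forward passes per step.
pos = model(latents, pos_emb)
neg = model(latents, neg_emb)
ref = neg + true_cfg * (pos - neg)

# Batched true CFG (this commit): one forward pass on a doubled batch.
emb = torch.cat([neg_emb, pos_emb], dim=0)  # negative half first, as in the diff
x = torch.cat([latents] * 2)
neg_b, pos_b = model(x, emb).chunk(2)       # chunk order mirrors the cat order
out = neg_b + true_cfg * (pos_b - neg_b)

assert torch.allclose(ref, out)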

@@ -289,80 +289,107 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
+ do_true_cfg: bool = False,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to be encoded.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in all text-encoders.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide image generation. Only used when `do_true_cfg` is `True`.
negative_prompt_2 (`str` or `List[str]`, *optional*):
The negative prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined,
`negative_prompt` is used in all text-encoders.
device (`torch.device`, *optional*):
The torch device on which to place the resulting embeddings.
num_images_per_prompt (`int`):
The number of images that should be generated per prompt.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from the `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from the `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings.
max_sequence_length (`int`, defaults to 512):
The maximum sequence length to use with `text_encoder_2`.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoders if LoRA layers are loaded.
do_true_cfg (`bool`, defaults to `False`):
Whether to encode `negative_prompt` in the same batch as `prompt`, so that true classifier-free
guidance can be applied without a separate text-encoding pass.
"""
device = device or self._execution_device
- # set lora scale so that monkey patched LoRA
- # function of text encoder can correctly access it
+ # Set LoRA scale if applicable
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if do_true_cfg and negative_prompt is not None:
+ negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_batch_size = len(negative_prompt)
+ if negative_batch_size != batch_size:
+ raise ValueError(
+ f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
+ )
+ # Concatenate prompts
+ prompts = prompt + negative_prompt
+ prompts_2 = (
+ prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None
+ )
+ else:
+ prompts = prompt
+ prompts_2 = prompt_2
if prompt_embeds is None:
- prompt_2 = prompt_2 or prompt
- prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+ if prompts_2 is None:
+ prompts_2 = prompts
- # We only use the pooled prompt output from the CLIPTextModel
+ # Get pooled prompt embeddings from CLIPTextModel
pooled_prompt_embeds = self._get_clip_prompt_embeds(
- prompt=prompt,
+ prompt=prompts,
device=device,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = self._get_t5_prompt_embeds(
- prompt=prompt_2,
+ prompt=prompts_2,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
+ if do_true_cfg and negative_prompt is not None:
+ # Split embeddings back into positive and negative parts
+ total_batch_size = batch_size * num_images_per_prompt
+ positive_indices = slice(0, total_batch_size)
+ negative_indices = slice(total_batch_size, 2 * total_batch_size)
+ positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
+ negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
+ positive_prompt_embeds = prompt_embeds[positive_indices]
+ negative_prompt_embeds = prompt_embeds[negative_indices]
+ pooled_prompt_embeds = positive_pooled_prompt_embeds
+ prompt_embeds = positive_prompt_embeds
+ # Unscale LoRA layers
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
- return prompt_embeds, pooled_prompt_embeds, text_ids
+ if do_true_cfg and negative_prompt is not None:
+ negative_text_ids = text_ids.clone()
+ return (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
+ )
+ else:
+ return prompt_embeds, pooled_prompt_embeds, text_ids, None, None, None
def check_inputs(
self,
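Worth spelling out the pattern the reworked `encode_prompt` above uses: positive and negative prompts are concatenated into one list, encoded through the text encoders in a single pass, and the resulting embeddings are sliced back apart via `positive_indices`/`negative_indices`. A stripped-down sketch of that slice-back logic (the `encode` function and the 512-dim shape are made up for illustration):

import torch

def encode(texts):
    # Hypothetical text encoder: one 512-dim embedding row per input text.
    return torch.randn(len(texts), 512)

prompt = ["a cat", "a dog"]
negative_prompt = ["blurry", "low quality"]
num_images_per_prompt = 2
batch_size = len(prompt)

# One encoder call for everything: positives first, then negatives.
embeds = encode(prompt + negative_prompt)
# Duplicate each row for num_images_per_prompt, preserving the grouping.
embeds = embeds.repeat_interleave(num_images_per_prompt, dim=0)

# Slice back apart, mirroring positive_indices / negative_indices above.
total_batch_size = batch_size * num_images_per_prompt
prompt_embeds = embeds[:total_batch_size]
negative_prompt_embeds = embeds[total_batch_size : 2 * total_batch_size]
assert prompt_embeds.shape == negative_prompt_embeds.shape == (total_batch_size, 512)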
@@ -687,38 +714,52 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
+ do_true_cfg = true_cfg > 1 and negative_prompt is not None
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
+ do_true_cfg=do_true_cfg,
)
# perform "real" CFG as suggested for distilled Flux models in https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md
- do_true_cfg = true_cfg > 1 and negative_prompt is not None
+ # do_true_cfg = true_cfg > 1 and negative_prompt is not None
+ # if do_true_cfg:
+ # (
+ # negative_prompt_embeds,
+ # negative_pooled_prompt_embeds,
+ # negative_text_ids,
+ # ) = self.encode_prompt(
+ # prompt=negative_prompt,
+ # prompt_2=negative_prompt_2,
+ # prompt_embeds=negative_prompt_embeds,
+ # pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ # device=device,
+ # num_images_per_prompt=num_images_per_prompt,
+ # max_sequence_length=max_sequence_length,
+ # lora_scale=lora_scale,
+ # )
if do_true_cfg:
- (
- negative_prompt_embeds,
- negative_pooled_prompt_embeds,
- negative_text_ids,
- ) = self.encode_prompt(
- prompt=negative_prompt,
- prompt_2=negative_prompt_2,
- prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=negative_pooled_prompt_embeds,
- device=device,
- num_images_per_prompt=num_images_per_prompt,
- max_sequence_length=max_sequence_length,
- lora_scale=lora_scale,
- )
+ # Concatenate embeddings
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+ text_ids = torch.cat([negative_text_ids, text_ids], dim=0)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
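A note on the concatenation above: the `[negative, positive]` order along `dim=0` is a contract with the `noise_pred.chunk(2)` call later in the denoising loop, since `chunk` splits the tensor back in the same order the halves were concatenated. A quick check of that invariant:

import torch

neg = torch.full((1, 4), -1.0)
pos = torch.full((1, 4), 1.0)
batch = torch.cat([neg, pos], dim=0)  # negative half first
neg_out, pos_out = batch.chunk(2)     # chunk(2) returns the halves in order
assert (neg_out < 0).all() and (pos_out > 0).all()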
@@ -732,6 +773,9 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
generator,
latents,
)
+ if do_true_cfg:
+ latent_image_ids = latent_image_ids.repeat(prompt_embeds.shape[0], 1, 1)
+ # expand the latents, too?
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
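The `# expand the latents, too?` question above is effectively answered in the next hunk: `latent_image_ids` is repeated once, up front, to match the doubled embedding batch, while the latents themselves are doubled per step via `latent_model_input`. A toy illustration of keeping those batch dimensions in sync (shapes are illustrative, not Flux's real ones):

import torch

do_true_cfg = True
latents = torch.randn(2, 1024, 64)          # (batch, tokens, channels)
latent_image_ids = torch.zeros(1, 1024, 3)  # shared positional ids, toy shape
prompt_embeds = torch.randn(4, 512, 64)     # already doubled: [negative; positive]

if do_true_cfg:
    # Repeated once before the loop, up to the doubled embedding batch size ...
    latent_image_ids = latent_image_ids.repeat(prompt_embeds.shape[0], 1, 1)

# ... while the latents are doubled inside the loop, at every step.
latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
assert latent_model_input.shape[0] == prompt_embeds.shape[0] == latent_image_ids.shape[0]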
@@ -767,11 +811,13 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
if self.interrupt:
continue
+ latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
noise_pred = self.transformer(
- hidden_states=latents,
+ hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
@@ -783,18 +829,7 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
)[0]
if do_true_cfg:
- neg_noise_pred = self.transformer(
- hidden_states=latents,
- timestep=timestep / 1000,
- guidance=guidance,
- pooled_projections=negative_pooled_prompt_embeds,
- encoder_hidden_states=negative_prompt_embeds,
- txt_ids=negative_text_ids,
- img_ids=latent_image_ids,
- joint_attention_kwargs=self.joint_attention_kwargs,
- return_dict=False,
- )[0]
+ neg_noise_pred, noise_pred = noise_pred.chunk(2)
noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
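Finally, the guidance combination itself is unchanged by the batching: `neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)` is algebraically the familiar classifier-free-guidance form `(1 - s) * neg + s * pos`. A quick numerical sanity check:

import torch

s = 4.0  # plays the role of true_cfg
pos = torch.randn(2, 8)
neg = torch.randn(2, 8)

a = neg + s * (pos - neg)    # form used in the diff
b = (1 - s) * neg + s * pos  # standard CFG form
assert torch.allclose(a, b, atol=1e-5)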