Compare commits

1 Commit

Author: sayakpaul
SHA1: 180813069e
Message: batched cfg implementation for qwenimage edit.
Date: 2025-09-04 10:18:27 +05:30

@@ -546,6 +546,7 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
+        batch_cfg: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -712,6 +713,14 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
         )
+        if batch_cfg:
+            target_len = max(negative_prompt_embeds.size(1), prompt_embeds.size(1))
+            negative_prompt_embeds = self._pad_to_len(negative_prompt_embeds, target_len, pad_value=0.0)
+            prompt_embeds = self._pad_to_len(prompt_embeds, target_len, pad_value=0.0)
+            negative_prompt_embeds_mask = self._pad_to_len(negative_prompt_embeds_mask, target_len, pad_value=0)
+            prompt_embeds_mask = self._pad_to_len(prompt_embeds_mask, target_len, pad_value=0)
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            prompt_embeds_mask = torch.cat([negative_prompt_embeds_mask, prompt_embeds_mask], dim=0)

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
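The padding step exists because the positive and negative prompts usually tokenize to different sequence lengths, while a single batched forward pass requires every row in the batch to share one shape. A minimal, self-contained sketch of the same pad-and-stack idea, with made-up shapes:

```python
import torch
import torch.nn.functional as F

# Zero-pad the shorter (negative) embedding along the sequence dim, then stack
# the two CFG branches into one batch. Shapes are illustrative only.
pos = torch.randn(1, 9, 64)   # [B, S_pos, D]
neg = torch.randn(1, 5, 64)   # [B, S_neg, D]
target_len = max(pos.size(1), neg.size(1))
neg = F.pad(neg, (0, 0, 0, target_len - neg.size(1)))  # pad dim 1 up to target_len
batched = torch.cat([neg, pos], dim=0)                 # [2B, target_len, D]
print(batched.shape)  # torch.Size([2, 9, 64])
```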
@@ -732,7 +741,9 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
             ]
         ] * batch_size
+        if batch_cfg:
+            img_shapes = img_shapes * 2

         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = latents.shape[1]
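Doubling `img_shapes` keeps the per-sample shape metadata aligned with the doubled latent batch: the transformer expects one shape entry per batch row. A toy illustration with arbitrary values:

```python
# One metadata entry per batch row; doubling the list mirrors the doubled
# (negative + positive) batch under batched CFG.
batch_size = 2
img_shapes = [[(1, 32, 32)]] * batch_size  # one shape list per sample
img_shapes = img_shapes * 2                # 2 * batch_size entries
assert len(img_shapes) == 4
```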
@@ -771,9 +782,10 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             self._attention_kwargs = {}

         txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
-        negative_txt_seq_lens = (
-            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
-        )
+        if not batch_cfg:
+            negative_txt_seq_lens = (
+                negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+            )

         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
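Under `batch_cfg`, `prompt_embeds_mask` already holds the negative and positive masks stacked along the batch dim, so a single `sum(dim=1)` yields per-row text lengths for both branches and a separate `negative_txt_seq_lens` is no longer needed. A toy illustration:

```python
import torch

# After concatenation along dim 0, one sum over the mask recovers the valid
# text length of every row, covering both CFG branches at once.
prompt_embeds_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])  # stacked [neg, pos]
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
print(txt_seq_lens)  # [3, 4]
```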
@@ -787,9 +799,14 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 latent_model_input = latents
                 if image_latents is not None:
                     latent_model_input = torch.cat([latents, image_latents], dim=1)
+                if batch_cfg:
+                    latent_model_input = torch.cat([latent_model_input] * 2)

                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                if batch_cfg:
+                    timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+                else:
+                    timestep = t.expand(latents.shape[0]).to(latents.dtype)

                 with self.transformer.cache_context("cond"):
                     noise_pred = self.transformer(
                         hidden_states=latent_model_input,
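The latents are duplicated so one forward pass covers both CFG branches, and the scalar timestep must then broadcast to 2B rows instead of B. A minimal sketch of the shape bookkeeping, with made-up sizes:

```python
import torch

# Duplicate the latent batch and broadcast the scalar timestep to match.
latents = torch.randn(2, 16, 8)                # B = 2
latent_model_input = torch.cat([latents] * 2)  # 2B = 4 rows under batch_cfg
t = torch.tensor(999.0)
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
print(timestep.shape)  # torch.Size([4])
```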
@@ -802,22 +819,25 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                         attention_kwargs=self.attention_kwargs,
                         return_dict=False,
                     )[0]
                     noise_pred = noise_pred[:, : latents.size(1)]

                 if do_true_cfg:
-                    with self.transformer.cache_context("uncond"):
-                        neg_noise_pred = self.transformer(
-                            hidden_states=latent_model_input,
-                            timestep=timestep / 1000,
-                            guidance=guidance,
-                            encoder_hidden_states_mask=negative_prompt_embeds_mask,
-                            encoder_hidden_states=negative_prompt_embeds,
-                            img_shapes=img_shapes,
-                            txt_seq_lens=negative_txt_seq_lens,
-                            attention_kwargs=self.attention_kwargs,
-                            return_dict=False,
-                        )[0]
-                        neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+                    if not batch_cfg:
+                        with self.transformer.cache_context("uncond"):
+                            neg_noise_pred = self.transformer(
+                                hidden_states=latent_model_input,
+                                timestep=timestep / 1000,
+                                guidance=guidance,
+                                encoder_hidden_states_mask=negative_prompt_embeds_mask,
+                                encoder_hidden_states=negative_prompt_embeds,
+                                img_shapes=img_shapes,
+                                txt_seq_lens=negative_txt_seq_lens,
+                                attention_kwargs=self.attention_kwargs,
+                                return_dict=False,
+                            )[0]
+                            neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+                    else:
+                        neg_noise_pred, noise_pred = noise_pred.chunk(2, dim=0)
                     comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                     cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
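In the batched path, the single forward pass returns the negative branch in the first half of the batch and the conditional branch in the second, matching the `[negative, positive]` concatenation order, so `chunk()` replaces the second transformer call. A self-contained sketch of the combine step, with dummy values:

```python
import torch

# Split the batched prediction back into the two CFG branches, then apply the
# usual classifier-free guidance combination.
noise_pred = torch.randn(4, 16, 8)  # 2B rows from a single transformer call
neg_noise_pred, noise_pred = noise_pred.chunk(2, dim=0)
true_cfg_scale = 4.0
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
print(comb_pred.shape)  # torch.Size([2, 16, 8])
```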
@@ -874,3 +894,23 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             return (image,)

         return QwenImagePipelineOutput(images=image)
+
+    @staticmethod
+    def _pad_to_len(x, target_len, pad_value=0.0):
+        # x: [B, S, D] or [B, S]
+        if x.dim() == 3:  # embeds
+            B, S, D = x.shape
+            if S == target_len:
+                return x
+            out = x.new_full((B, target_len, D), pad_value)
+            out[:, :S, :] = x
+            return out
+        elif x.dim() == 2:  # mask
+            B, S = x.shape
+            if S == target_len:
+                return x
+            out = x.new_zeros((B, target_len), dtype=x.dtype)
+            out[:, :S] = x
+            return out
+        else:
+            raise ValueError("Unexpected tensor rank")
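Since `_pad_to_len` is a `@staticmethod`, it can be exercised in isolation. A quick sanity check, assuming `QwenImageEditPipeline` is importable from the patched module:

```python
import torch
from diffusers import QwenImageEditPipeline  # assumes the patched class is installed

embeds = torch.randn(2, 5, 64)
mask = torch.ones(2, 5, dtype=torch.long)
padded_embeds = QwenImageEditPipeline._pad_to_len(embeds, 9)  # -> [2, 9, 64]
padded_mask = QwenImageEditPipeline._pad_to_len(mask, 9)      # -> [2, 9], zeros appended
assert padded_embeds.shape == (2, 9, 64) and padded_mask[:, 5:].eq(0).all()
```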