mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-25 13:06:06 +08:00
Compare commits
1 Commits
modular-lo
...
qwenimage-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
180813069e |
@@ -546,6 +546,7 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
batch_cfg: bool = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -712,6 +713,14 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if batch_cfg:
|
||||
target_len = max(negative_prompt_embeds.size(1), prompt_embeds.size(1))
|
||||
negative_prompt_embeds = self._pad_to_len(negative_prompt_embeds, target_len, pad_value=0.0)
|
||||
prompt_embeds = self._pad_to_len(prompt_embeds, target_len, pad_value=0.0)
|
||||
negative_prompt_embeds_mask = self._pad_to_len(negative_prompt_embeds_mask, target_len, pad_value=0)
|
||||
prompt_embeds_mask = self._pad_to_len(prompt_embeds_mask, target_len, pad_value=0)
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_embeds_mask = torch.cat([negative_prompt_embeds_mask, prompt_embeds_mask], dim=0)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
@@ -732,7 +741,9 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
(1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
|
||||
]
|
||||
] * batch_size
|
||||
|
||||
if batch_cfg:
|
||||
img_shapes = img_shapes * 2
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
image_seq_len = latents.shape[1]
|
||||
@@ -771,9 +782,10 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
if not batch_cfg:
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
@@ -787,9 +799,14 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
latent_model_input = latents
|
||||
if image_latents is not None:
|
||||
latent_model_input = torch.cat([latents, image_latents], dim=1)
|
||||
if batch_cfg:
|
||||
latent_model_input = torch.cat([latent_model_input] * 2)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
if batch_cfg:
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
|
||||
else:
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
@@ -802,22 +819,25 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred[:, : latents.size(1)]
|
||||
noise_pred = noise_pred[:, : latents.size(1)]
|
||||
|
||||
if do_true_cfg:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
|
||||
if not batch_cfg:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
|
||||
else:
|
||||
neg_noise_pred, noise_pred = noise_pred.chunk(2, dim=0)
|
||||
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
||||
@@ -874,3 +894,23 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
return (image,)
|
||||
|
||||
return QwenImagePipelineOutput(images=image)
|
||||
|
||||
@staticmethod
|
||||
def _pad_to_len(x, target_len, pad_value=0.0):
|
||||
# x: [B, S, D] or [B, S]
|
||||
if x.dim() == 3: # embeds
|
||||
B, S, D = x.shape
|
||||
if S == target_len:
|
||||
return x
|
||||
out = x.new_full((B, target_len, D), pad_value)
|
||||
out[:, :S, :] = x
|
||||
return out
|
||||
elif x.dim() == 2: # mask
|
||||
B, S = x.shape
|
||||
if S == target_len:
|
||||
return x
|
||||
out = x.new_zeros((B, target_len), dtype=x.dtype)
|
||||
out[:, :S] = x
|
||||
return out
|
||||
else:
|
||||
raise ValueError("Unexpected tensor rank")
|
||||
|
||||
Reference in New Issue
Block a user