Compare commits

...

1 Commits

Author SHA1 Message Date
sayakpaul
51fadecc66 feat: make QwenImage family fully compilable again.
Co-authored-by: apolinario <joaopaulo.passos@gmail.com>
Co-authored-by: cbensimon <charles@huggingface.co>
2025-08-27 09:40:23 +02:00
2 changed files with 9 additions and 6 deletions

View File

@@ -557,6 +557,7 @@ class QwenImageTransformer2DModel(
attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
return_dict: bool = True,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`QwenTransformer2DModel`] forward method.
@@ -611,8 +612,8 @@ class QwenImageTransformer2DModel(
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
if image_rotary_emb is None:
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:

View File

@@ -631,6 +631,10 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
neg_image_rotary_emb = None
if do_true_cfg:
neg_image_rotary_emb = self.transformer.pos_embed(img_shapes, negative_txt_seq_lens, device=latents.device)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
@@ -649,8 +653,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
guidance=guidance,
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
image_rotary_emb=image_rotary_emb,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -663,8 +666,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
guidance=guidance,
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
image_rotary_emb=neg_image_rotary_emb,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]