Compare commits

...

4 Commits

Author SHA1 Message Date
sayakpaul
e7a4e40181 fix fa4 integration 2026-04-10 20:31:56 +05:30
Akshan Krithick
4548e68e80 Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage (#13406)
* Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage

* Apply style fixes

* use lru_cache_unless_export

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-04-09 23:41:50 -07:00
Chenyang Zhu
b80d3f6872 fix(qwen-image dreambooth): correct prompt embed repeats when using --with_prior_preservation (#13396)
fix(qwen): correct prompt embed repeats with prior preservation

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-04-10 10:17:06 +05:30
Chenyang Zhu
acc07f5cda Handle prompt embedding concat in Qwen dreambooth example (#13387)
* Handle prompt embedding concat in Qwen dreambooth example

* remove wandb config

* Apply style fixes

* add a comment on how this is only relevant during prior preservation.

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-04-10 09:13:32 +05:30
3 changed files with 96 additions and 14 deletions

View File

@@ -906,6 +906,68 @@ class PromptDataset(Dataset):
return example
# These helpers only matter for prior preservation, where instance and class prompt
# embedding batches are concatenated and may not share the same mask/sequence length.
def _materialize_prompt_embedding_mask(
prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None
) -> torch.Tensor:
"""Return a dense mask tensor for a prompt embedding batch."""
batch_size, seq_len = prompt_embeds.shape[:2]
if prompt_embeds_mask is None:
return torch.ones((batch_size, seq_len), dtype=torch.long, device=prompt_embeds.device)
if prompt_embeds_mask.shape != (batch_size, seq_len):
raise ValueError(
f"`prompt_embeds_mask` shape {prompt_embeds_mask.shape} must match prompt embeddings shape "
f"({batch_size}, {seq_len})."
)
return prompt_embeds_mask.to(device=prompt_embeds.device)
def _pad_prompt_embedding_pair(
    prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None, target_seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Right-pad an embedding batch and its mask along the sequence dimension.

    Args:
        prompt_embeds: Embedding batch; dimension 1 is the one that gets padded.
        prompt_embeds_mask: Mask for the batch, or ``None``. It is first turned
            into a dense mask via ``_materialize_prompt_embedding_mask``.
        target_seq_len: Sequence length to pad up to.

    Returns:
        ``(prompt_embeds, mask)``. When the batch is already at least
        ``target_seq_len`` long, the embeddings are returned untouched together
        with the dense mask; otherwise both are extended with zeros at the end
        of dimension 1.
    """
    dense_mask = _materialize_prompt_embedding_mask(prompt_embeds, prompt_embeds_mask)

    missing = target_seq_len - prompt_embeds.shape[1]
    if missing > 0:
        # Zero blocks inherit dtype/device from the tensors they extend.
        embeds_padding = prompt_embeds.new_zeros(prompt_embeds.shape[0], missing, prompt_embeds.shape[2])
        mask_padding = dense_mask.new_zeros(dense_mask.shape[0], missing)
        prompt_embeds = torch.cat([prompt_embeds, embeds_padding], dim=1)
        dense_mask = torch.cat([dense_mask, mask_padding], dim=1)

    return prompt_embeds, dense_mask
def concat_prompt_embedding_batches(
    *prompt_embedding_pairs: tuple[torch.Tensor, torch.Tensor | None],
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Stack several ``(prompt_embeds, prompt_embeds_mask)`` batches along dim 0.

    Each pair is padded (via ``_pad_prompt_embedding_pair``) to the longest
    sequence length found among the inputs, so batches with different lengths
    or with a missing mask can still be concatenated.

    Args:
        *prompt_embedding_pairs: One or more ``(embeddings, mask_or_None)`` tuples.

    Returns:
        ``(merged_prompt_embeds, merged_mask)``. The mask is ``None`` when every
        entry of the merged mask is truthy, i.e. nothing is masked out.

    Raises:
        ValueError: If called without any pair.
    """
    if len(prompt_embedding_pairs) == 0:
        raise ValueError("At least one prompt embedding pair must be provided.")

    # The longest batch dictates the shared sequence length.
    longest_seq_len = max(embeds.shape[1] for embeds, _ in prompt_embedding_pairs)

    embeds_parts = []
    mask_parts = []
    for embeds, mask in prompt_embedding_pairs:
        padded_embeds, padded_mask = _pad_prompt_embedding_pair(embeds, mask, longest_seq_len)
        embeds_parts.append(padded_embeds)
        mask_parts.append(padded_mask)

    merged_prompt_embeds = torch.cat(embeds_parts, dim=0)
    merged_mask = torch.cat(mask_parts, dim=0)

    # An all-ones mask carries no information; report it as "no mask".
    if merged_mask.all():
        merged_mask = None
    return merged_prompt_embeds, merged_mask
def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
@@ -1320,8 +1382,10 @@ def main(args):
prompt_embeds = instance_prompt_embeds
prompt_embeds_mask = instance_prompt_embeds_mask
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)
prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches(
(instance_prompt_embeds, instance_prompt_embeds_mask),
(class_prompt_embeds, class_prompt_embeds_mask),
)
# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
@@ -1465,7 +1529,10 @@ def main(args):
prompt_embeds = prompt_embeds_cache[step]
prompt_embeds_mask = prompt_embeds_mask_cache[step]
else:
num_repeat_elements = len(prompts)
# With prior preservation, prompt_embeds already contains [instance, class] embeddings
# from the cat above, but collate_fn also doubles the prompts list. Use half the
# prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)

View File

@@ -540,7 +540,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_version(">=", "0.12.3"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
)

View File

@@ -233,6 +233,11 @@ class QwenEmbedRope(nn.Module):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
@lru_cache_unless_export(maxsize=None)
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(pos_freqs, neg_freqs)`` moved to ``device``.

    NOTE(review): the decorator is defined outside this view; by its name it
    presumably memoizes the result per ``device`` except during export, so
    callers should treat the returned tensors as read-only — confirm.
    """
    pos_on_device = self.pos_freqs.to(device)
    neg_on_device = self.neg_freqs.to(device)
    return pos_on_device, neg_on_device
def forward(
self,
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
@@ -300,8 +305,9 @@ class QwenEmbedRope(nn.Module):
max_vid_index = max(height, width, max_vid_index)
max_txt_seq_len_int = int(max_txt_seq_len)
# Create device-specific copy for text freqs without modifying self.pos_freqs
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
pos_freqs_device, _ = self._get_device_freqs(device)
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@@ -311,8 +317,9 @@ class QwenEmbedRope(nn.Module):
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
pos_freqs, neg_freqs = (
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
)
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -367,6 +374,11 @@ class QwenEmbedLayer3DRope(nn.Module):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
@lru_cache_unless_export(maxsize=None)
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(pos_freqs, neg_freqs)`` moved to ``device``.

    NOTE(review): the decorator is defined outside this view; by its name it
    presumably memoizes the result per ``device`` except during export, so
    callers should treat the returned tensors as read-only — confirm.
    """
    pos_on_device = self.pos_freqs.to(device)
    neg_on_device = self.neg_freqs.to(device)
    return pos_on_device, neg_on_device
def forward(
self,
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
@@ -421,8 +433,9 @@ class QwenEmbedLayer3DRope(nn.Module):
max_vid_index = max(max_vid_index, layer_num)
max_txt_seq_len_int = int(max_txt_seq_len)
# Create device-specific copy for text freqs without modifying self.pos_freqs
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
pos_freqs_device, _ = self._get_device_freqs(device)
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@@ -430,8 +443,9 @@ class QwenEmbedLayer3DRope(nn.Module):
@lru_cache_unless_export(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
pos_freqs, neg_freqs = (
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
)
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -452,8 +466,9 @@ class QwenEmbedLayer3DRope(nn.Module):
@lru_cache_unless_export(maxsize=None)
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
pos_freqs, neg_freqs = (
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
)
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)