Mirror of https://github.com/huggingface/diffusers.git (synced 2026-04-11 02:01:57 +08:00)

Compare commits: fix-review ... fix-textua (7 commits)
Commits compared: 793d24d327, a753642a50, d4386f4231, 896fec351b, 4548e68e80, b80d3f6872, acc07f5cda
Changed file: .github/workflows/pr_dependency_test.yml (vendored)
@@ -6,6 +6,7 @@ on:
       - main
     paths:
       - "src/diffusers/**.py"
+      - "tests/**.py"
   push:
     branches:
       - main
@@ -26,7 +27,7 @@ jobs:
     - name: Install dependencies
       run: |
         pip install -e .
-        pip install torch torchvision torchaudio pytest
+        pip install torch pytest
     - name: Check for soft dependencies
       run: |
        pytest tests/others/test_dependencies.py
@@ -895,9 +895,8 @@ class TokenEmbeddingsHandler:
             self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)

             # random initialization of new tokens
-            embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
-            )
+            text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
             std_token_embedding = embeds.weight.data.std()

             logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")

@@ -905,9 +904,7 @@ class TokenEmbeddingsHandler:
             train_ids = self.train_ids if idx == 0 else self.train_ids_t5
             # if initializer_concept are not provided, token embeddings are initialized randomly
             if args.initializer_concept is None:
-                hidden_size = (
-                    text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
-                )
+                hidden_size = text_module.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
                 embeds.weight.data[train_ids] = (
                     torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
                     * std_token_embedding

@@ -940,7 +937,8 @@ class TokenEmbeddingsHandler:
         idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
         for idx, text_encoder in enumerate(self.text_encoders):
             train_ids = self.train_ids if idx == 0 else self.train_ids_t5
-            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
+            text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
             assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
             new_token_embeddings = embeds.weight.data[train_ids]

@@ -962,7 +960,8 @@ class TokenEmbeddingsHandler:
     @torch.no_grad()
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
-            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
+            text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
             embeds.weight.data[index_no_updates] = (
                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]

@@ -2112,7 +2111,8 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            _te_one = unwrap_model(text_encoder_one)
+            (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
         elif args.train_text_encoder_ti:  # textual inversion / pivotal tuning
             text_encoder_one.train()
             if args.enable_t5_ti:
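
All of the training-script hunks in this comparison apply the same defensive-access pattern: CLIP text encoders keep their transformer under a `.text_model` attribute, while encoders without that wrapper expose `embeddings` and `config` at the top level. A minimal sketch of the idea, with a hypothetical helper name that does not exist in the scripts themselves:

    def _get_text_module(text_encoder):
        # CLIPTextModel keeps its transformer under `.text_model`; fall back to the
        # module itself when that wrapper is absent.
        return text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder

Some hunks bind this to a local `text_module` variable, others inline the conditional expression at each use site, so each script stays self-contained.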
@@ -763,19 +763,28 @@ class TokenEmbeddingsHandler:
             self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)

             # random initialization of new tokens
-            std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+            std_token_embedding = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.std()

             print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")

-            text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
-                torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[self.train_ids] = (
+                torch.randn(
+                    len(self.train_ids),
+                    (
+                        text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+                    ).config.hidden_size,
+                )
                 .to(device=self.device)
                 .to(dtype=self.dtype)
                 * std_token_embedding
             )
             self.embeddings_settings[f"original_embeddings_{idx}"] = (
-                text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
-            )
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.clone()
             self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding

             inu = torch.ones((len(tokenizer),), dtype=torch.bool)

@@ -794,10 +803,14 @@ class TokenEmbeddingsHandler:
         # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
         idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
-            assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
-                self.tokenizers[0]
-            ), "Tokenizers should be the same."
-            new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+            assert (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
+                "Tokenizers should be the same."
+            )
+            new_token_embeddings = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[self.train_ids]

             # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
             # text_encoder 1) to keep compatible with the ecosystem.

@@ -819,7 +832,9 @@ class TokenEmbeddingsHandler:
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
-            text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_no_updates] = (
                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
                 .to(device=text_encoder.device)
                 .to(dtype=text_encoder.dtype)

@@ -830,11 +845,15 @@ class TokenEmbeddingsHandler:
             std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]

             index_updates = ~index_no_updates
-            new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+            new_embeddings = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_updates]
             off_ratio = std_token_embedding / new_embeddings.std()

             new_embeddings = new_embeddings * (off_ratio**0.1)
-            text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_updates] = new_embeddings


 class DreamBoothDataset(Dataset):
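
The `off_ratio**0.1` step above softly renormalizes the trained token embeddings toward their pre-training scale instead of snapping them back in one shot. A small numeric illustration (all values made up):

    import torch

    target_std = 0.014                     # std recorded before training
    trained = torch.randn(2, 768) * 0.05   # drifted new-token embeddings, std ~0.05
    off_ratio = target_std / trained.std()
    rescaled = trained * (off_ratio ** 0.1)
    # rescaled.std() moves only ~10% of the way (in log scale) from 0.05 toward 0.014,
    # so each retract_embeddings call nudges the scale rather than rescaling outright.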
@@ -1704,7 +1723,8 @@ def main(args):
                 text_encoder_one.train()
                 # set top parameter requires_grad = True for gradient checkpointing works
                 if args.train_text_encoder:
-                    text_encoder_one.text_model.embeddings.requires_grad_(True)
+                    _te_one = text_encoder_one
+                    (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)

         unet.train()
         for step, batch in enumerate(train_dataloader):
@@ -929,19 +929,28 @@ class TokenEmbeddingsHandler:
             self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)

             # random initialization of new tokens
-            std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+            std_token_embedding = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.std()

             print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")

-            text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
-                torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[self.train_ids] = (
+                torch.randn(
+                    len(self.train_ids),
+                    (
+                        text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+                    ).config.hidden_size,
+                )
                 .to(device=self.device)
                 .to(dtype=self.dtype)
                 * std_token_embedding
             )
             self.embeddings_settings[f"original_embeddings_{idx}"] = (
-                text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
-            )
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.clone()
             self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding

             inu = torch.ones((len(tokenizer),), dtype=torch.bool)

@@ -959,10 +968,14 @@ class TokenEmbeddingsHandler:
         # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
         idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
-            assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
-                self.tokenizers[0]
-            ), "Tokenizers should be the same."
-            new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+            assert (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
+                "Tokenizers should be the same."
+            )
+            new_token_embeddings = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[self.train_ids]

             # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
             # text_encoder 1) to keep compatible with the ecosystem.

@@ -984,7 +997,9 @@ class TokenEmbeddingsHandler:
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
-            text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_no_updates] = (
                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
                 .to(device=text_encoder.device)
                 .to(dtype=text_encoder.dtype)

@@ -995,11 +1010,15 @@ class TokenEmbeddingsHandler:
             std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]

             index_updates = ~index_no_updates
-            new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+            new_embeddings = (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_updates]
             off_ratio = std_token_embedding / new_embeddings.std()

             new_embeddings = new_embeddings * (off_ratio**0.1)
-            text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+            (
+                text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+            ).embeddings.token_embedding.weight.data[index_updates] = new_embeddings


 class DreamBoothDataset(Dataset):

@@ -2083,8 +2102,10 @@ def main(args):
                 text_encoder_two.train()
                 # set top parameter requires_grad = True for gradient checkpointing works
                 if args.train_text_encoder:
-                    accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
-                    accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+                    _te_one = accelerator.unwrap_model(text_encoder_one)
+                    (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+                    _te_two = accelerator.unwrap_model(text_encoder_two)
+                    (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             if pivoted:
@@ -874,10 +874,11 @@ def main(args):
                 token_embeds[x] = token_embeds[y]

     # Freeze all parameters except for the token embeddings in text encoder
+    text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
     params_to_freeze = itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
+        text_module.encoder.parameters(),
+        text_module.final_layer_norm.parameters(),
+        text_module.embeddings.position_embedding.parameters(),
     )
     freeze_params(params_to_freeze)
     ########################################################
@@ -1691,7 +1691,8 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            _te_one = unwrap_model(text_encoder_one)
+            (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
@@ -1896,7 +1896,8 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            _te_one = unwrap_model(text_encoder_one)
+            (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
@@ -906,6 +906,68 @@ class PromptDataset(Dataset):
         return example


+# These helpers only matter for prior preservation, where instance and class prompt
+# embedding batches are concatenated and may not share the same mask/sequence length.
+def _materialize_prompt_embedding_mask(
+    prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None
+) -> torch.Tensor:
+    """Return a dense mask tensor for a prompt embedding batch."""
+    batch_size, seq_len = prompt_embeds.shape[:2]
+
+    if prompt_embeds_mask is None:
+        return torch.ones((batch_size, seq_len), dtype=torch.long, device=prompt_embeds.device)
+
+    if prompt_embeds_mask.shape != (batch_size, seq_len):
+        raise ValueError(
+            f"`prompt_embeds_mask` shape {prompt_embeds_mask.shape} must match prompt embeddings shape "
+            f"({batch_size}, {seq_len})."
+        )
+
+    return prompt_embeds_mask.to(device=prompt_embeds.device)
+
+
+def _pad_prompt_embedding_pair(
+    prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None, target_seq_len: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Pad one prompt embedding batch and its mask to a shared sequence length."""
+    prompt_embeds_mask = _materialize_prompt_embedding_mask(prompt_embeds, prompt_embeds_mask)
+    pad_width = target_seq_len - prompt_embeds.shape[1]
+
+    if pad_width <= 0:
+        return prompt_embeds, prompt_embeds_mask
+
+    prompt_embeds = torch.cat(
+        [prompt_embeds, prompt_embeds.new_zeros(prompt_embeds.shape[0], pad_width, prompt_embeds.shape[2])], dim=1
+    )
+    prompt_embeds_mask = torch.cat(
+        [prompt_embeds_mask, prompt_embeds_mask.new_zeros(prompt_embeds_mask.shape[0], pad_width)], dim=1
+    )
+
+    return prompt_embeds, prompt_embeds_mask
+
+
+def concat_prompt_embedding_batches(
+    *prompt_embedding_pairs: tuple[torch.Tensor, torch.Tensor | None],
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+    """Concatenate prompt embedding batches while handling missing masks and length mismatches."""
+    if not prompt_embedding_pairs:
+        raise ValueError("At least one prompt embedding pair must be provided.")
+
+    target_seq_len = max(prompt_embeds.shape[1] for prompt_embeds, _ in prompt_embedding_pairs)
+    padded_pairs = [
+        _pad_prompt_embedding_pair(prompt_embeds, prompt_embeds_mask, target_seq_len)
+        for prompt_embeds, prompt_embeds_mask in prompt_embedding_pairs
+    ]
+
+    merged_prompt_embeds = torch.cat([prompt_embeds for prompt_embeds, _ in padded_pairs], dim=0)
+    merged_mask = torch.cat([prompt_embeds_mask for _, prompt_embeds_mask in padded_pairs], dim=0)
+
+    if merged_mask.all():
+        return merged_prompt_embeds, None
+
+    return merged_prompt_embeds, merged_mask
+
+
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
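
A minimal, illustrative use of `concat_prompt_embedding_batches` (shapes invented for the example): two batches with different sequence lengths and a missing mask are padded to the longer length, and a merged mask is returned only when padding actually introduced masked positions.

    import torch

    instance_embeds = torch.randn(2, 20, 64)             # instance prompts, seq len 20
    instance_mask = torch.ones(2, 20, dtype=torch.long)
    class_embeds = torch.randn(2, 32, 64)                 # class prompts, seq len 32
    class_mask = None                                     # treated as all-ones

    embeds, mask = concat_prompt_embedding_batches(
        (instance_embeds, instance_mask),
        (class_embeds, class_mask),
    )
    assert embeds.shape == (4, 32, 64)
    assert mask is not None and not mask[:2, 20:].any()   # zero-padded tail of the instance rows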
@@ -1320,8 +1382,10 @@ def main(args):
             prompt_embeds = instance_prompt_embeds
             prompt_embeds_mask = instance_prompt_embeds_mask
             if args.with_prior_preservation:
-                prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
-                prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)
+                prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches(
+                    (instance_prompt_embeds, instance_prompt_embeds_mask),
+                    (class_prompt_embeds, class_prompt_embeds_mask),
+                )

         # if cache_latents is set to True, we encode images to latents and store them.
         # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
@@ -1465,7 +1529,10 @@ def main(args):
                     prompt_embeds = prompt_embeds_cache[step]
                     prompt_embeds_mask = prompt_embeds_mask_cache[step]
                 else:
-                    num_repeat_elements = len(prompts)
+                    # With prior preservation, prompt_embeds already contains [instance, class] embeddings
+                    # from the cat above, but collate_fn also doubles the prompts list. Use half the
+                    # prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
+                    num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
                     prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
                     if prompt_embeds_mask is not None:
                         prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
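
A quick arithmetic check of the repeat fix above, for the single-prompt cached path it applies to (batch size invented for the example):

    # With train_batch_size = 4 and prior preservation, collate_fn yields 8 prompts
    # (4 instance + 4 class) and 8 latents, while prompt_embeds already stacks the
    # single [instance, class] pair, i.e. 2 rows.
    num_prompts = 8
    embed_rows = 2
    assert embed_rows * num_prompts == 16           # old behaviour: 16 embeddings for 8 latents
    assert embed_rows * (num_prompts // 2) == 8     # fixed: matches the latent batch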
@@ -1719,8 +1719,10 @@ def main(args):
             text_encoder_two.train()

             # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
-            accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+            _te_one = accelerator.unwrap_model(text_encoder_one)
+            (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+            _te_two = accelerator.unwrap_model(text_encoder_two)
+            (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
@@ -1661,8 +1661,10 @@ def main(args):
             text_encoder_two.train()

             # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
-            accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+            _te_one = accelerator.unwrap_model(text_encoder_one)
+            (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+            _te_two = accelerator.unwrap_model(text_encoder_two)
+            (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
@@ -702,9 +702,10 @@ def main():
     vae.requires_grad_(False)
     unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    text_encoder.text_model.encoder.requires_grad_(False)
-    text_encoder.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+    text_module.encoder.requires_grad_(False)
+    text_module.final_layer_norm.requires_grad_(False)
+    text_module.embeddings.position_embedding.requires_grad_(False)

     if args.gradient_checkpointing:
         # Keep unet in train mode if we are using gradient checkpointing to save memory.
@@ -717,12 +717,14 @@ def main():
     unet.requires_grad_(False)

     # Freeze all parameters except for the token embeddings in text encoder
-    text_encoder_1.text_model.encoder.requires_grad_(False)
-    text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder_2.text_model.encoder.requires_grad_(False)
-    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_module_1 = text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+    text_module_1.encoder.requires_grad_(False)
+    text_module_1.final_layer_norm.requires_grad_(False)
+    text_module_1.embeddings.position_embedding.requires_grad_(False)
+    text_module_2 = text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+    text_module_2.encoder.requires_grad_(False)
+    text_module_2.final_layer_norm.requires_grad_(False)
+    text_module_2.embeddings.position_embedding.requires_grad_(False)

     if args.gradient_checkpointing:
         text_encoder_1.gradient_checkpointing_enable()

@@ -767,8 +769,12 @@ def main():
     optimizer = optimizer_class(
         # only optimize the embeddings
         [
-            text_encoder_1.text_model.embeddings.token_embedding.weight,
-            text_encoder_2.text_model.embeddings.token_embedding.weight,
+            (
+                text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+            ).embeddings.token_embedding.weight,
+            (
+                text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+            ).embeddings.token_embedding.weight,
         ],
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
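
A sanity check that matches the freezing above (the assertion is illustrative, not part of the diff, and assumes the first encoder is a standard CLIPTextModel): after the `requires_grad_(False)` calls, the token embedding matrix handed to the optimizer should be the only trainable parameter left.

    trainable = [name for name, p in text_encoder_1.named_parameters() if p.requires_grad]
    # For a standard CLIPTextModel this is exactly the matrix the optimizer receives.
    assert trainable == ["text_model.embeddings.token_embedding.weight"]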
@@ -233,6 +233,11 @@ class QwenEmbedRope(nn.Module):
         freqs = torch.polar(torch.ones_like(freqs), freqs)
         return freqs

+    @lru_cache_unless_export(maxsize=None)
+    def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+        """Return pos_freqs and neg_freqs on the given device."""
+        return self.pos_freqs.to(device), self.neg_freqs.to(device)
+
     def forward(
         self,
         video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
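
The same idea in a standalone sketch, using `functools.lru_cache` in place of diffusers' internal `lru_cache_unless_export` decorator (assumed here to behave like a plain `lru_cache` outside of export/tracing): the per-device tensors are produced once and reused, so repeated forward calls stop paying for a `.to(device)` check-and-copy.

    import functools
    import torch

    class RopeBuffers:
        def __init__(self, seq: int = 1024, dim: int = 64):
            self.pos_freqs = torch.randn(seq, dim)
            self.neg_freqs = torch.randn(seq, dim)

        @functools.lru_cache(maxsize=None)
        def _get_device_freqs(self, device: torch.device):
            # First call per device copies the buffers; later calls hit the cache.
            return self.pos_freqs.to(device), self.neg_freqs.to(device)

Note that caching on a bound method keys the cache on `self` as well as `device`, which keeps the module alive for the lifetime of the cache; the trade-off is accepted here in exchange for avoiding repeated transfers.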
@@ -300,8 +305,9 @@
                 max_vid_index = max(height, width, max_vid_index)

         max_txt_seq_len_int = int(max_txt_seq_len)
-        # Create device-specific copy for text freqs without modifying self.pos_freqs
-        txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
+        # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
+        pos_freqs_device, _ = self._get_device_freqs(device)
+        txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
         vid_freqs = torch.cat(vid_freqs, dim=0)

         return vid_freqs, txt_freqs

@@ -311,8 +317,9 @@
         self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
     ) -> torch.Tensor:
         seq_lens = frame * height * width
-        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
-        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+        pos_freqs, neg_freqs = (
+            self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
+        )

         freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

@@ -367,6 +374,11 @@ class QwenEmbedLayer3DRope(nn.Module):
         freqs = torch.polar(torch.ones_like(freqs), freqs)
         return freqs

+    @lru_cache_unless_export(maxsize=None)
+    def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+        """Return pos_freqs and neg_freqs on the given device."""
+        return self.pos_freqs.to(device), self.neg_freqs.to(device)
+
     def forward(
         self,
         video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],

@@ -421,8 +433,9 @@ class QwenEmbedLayer3DRope(nn.Module):

         max_vid_index = max(max_vid_index, layer_num)
         max_txt_seq_len_int = int(max_txt_seq_len)
-        # Create device-specific copy for text freqs without modifying self.pos_freqs
-        txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
+        # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
+        pos_freqs_device, _ = self._get_device_freqs(device)
+        txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
         vid_freqs = torch.cat(vid_freqs, dim=0)

         return vid_freqs, txt_freqs

@@ -430,8 +443,9 @@ class QwenEmbedLayer3DRope(nn.Module):
     @lru_cache_unless_export(maxsize=None)
     def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
         seq_lens = frame * height * width
-        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
-        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+        pos_freqs, neg_freqs = (
+            self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
+        )

         freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

@@ -452,8 +466,9 @@ class QwenEmbedLayer3DRope(nn.Module):
     @lru_cache_unless_export(maxsize=None)
     def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
         seq_lens = frame * height * width
-        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
-        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+        pos_freqs, neg_freqs = (
+            self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
+        )

         freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -5,10 +5,13 @@ import cv2
 import numpy as np
 import torch
 from PIL import Image, ImageOps
-from torchvision.transforms import InterpolationMode
-from torchvision.transforms.functional import normalize, resize

-from ...utils import get_logger, load_image
+from ...utils import get_logger, is_torchvision_available, load_image
+
+
+if is_torchvision_available():
+    from torchvision.transforms import InterpolationMode
+    from torchvision.transforms.functional import normalize, resize


 logger = get_logger(__name__)
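
Guarding the import keeps the module importable when torchvision is absent; code that actually needs it then fails at call time instead. A sketch of that follow-through (the `_resize` helper is hypothetical and not part of the diff; `is_torchvision_available` is the same diffusers.utils check used above):

    from diffusers.utils import is_torchvision_available

    if is_torchvision_available():
        from torchvision.transforms.functional import resize


    def _resize(image, size):
        # Hypothetical call-time guard: raise a clear error instead of a NameError.
        if not is_torchvision_available():
            raise ImportError("This image processor requires torchvision. Install it with `pip install torchvision`.")
        return resize(image, size)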
@@ -13,16 +13,14 @@
 # limitations under the License.

 import inspect
-import unittest
 from importlib import import_module

+import pytest
+

-class DependencyTester(unittest.TestCase):
+class TestDependencies:
     def test_diffusers_import(self):
-        try:
-            import diffusers  # noqa: F401
-        except ImportError:
-            assert False
+        import diffusers  # noqa: F401

     def test_backend_registration(self):
         import diffusers

@@ -52,3 +50,36 @@ class DependencyTester(unittest.TestCase):
             if hasattr(diffusers.pipelines, cls_name):
                 pipeline_folder_module = ".".join(str(cls_module.__module__).split(".")[:3])
                 _ = import_module(pipeline_folder_module, str(cls_name))
+
+    def test_pipeline_module_imports(self):
+        """Import every pipeline submodule whose dependencies are satisfied,
+        to catch unguarded optional-dep imports (e.g., torchvision).
+
+        Uses inspect.getmembers to discover classes that the lazy loader can
+        actually resolve (same self-filtering as test_pipeline_imports), then
+        imports the full module path instead of truncating to the folder level.
+        """
+        import diffusers
+        import diffusers.pipelines
+
+        failures = []
+        all_classes = inspect.getmembers(diffusers, inspect.isclass)
+
+        for cls_name, cls_module in all_classes:
+            if not hasattr(diffusers.pipelines, cls_name):
+                continue
+            if "dummy_" in cls_module.__module__:
+                continue
+
+            full_module_path = cls_module.__module__
+            try:
+                import_module(full_module_path)
+            except ImportError as e:
+                failures.append(f"{full_module_path}: {e}")
+            except Exception:
+                # Non-import errors (e.g., missing config) are fine; we only
+                # care about unguarded import statements.
+                pass
+
+        if failures:
+            pytest.fail("Unguarded optional-dependency imports found:\n" + "\n".join(failures))
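
Because the workflow above now installs only torch and pytest, torchvision is guaranteed to be absent when this test runs, so any pipeline module that imports it at the top level fails collection in `test_pipeline_module_imports`. A toy illustration of the two module shapes it distinguishes (module contents invented for the example):

    # unguarded_module.py -- importing this fails outright when torchvision is missing,
    # so test_pipeline_module_imports records it as a failure.
    from torchvision.transforms import InterpolationMode  # noqa: F401

    # guarded_module.py -- always importable; torchvision is only touched when available.
    from diffusers.utils import is_torchvision_available

    if is_torchvision_available():
        from torchvision.transforms import InterpolationMode  # noqa: F401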