Compare commits


4 Commits

Author       SHA1         Message                                          Date
Sayak Paul   4a5d8c37d4   Merge branch 'main' into fix-textual-inversion   2026-04-11 10:05:22 +05:30
Sayak Paul   793d24d327   Merge branch 'main' into fix-textual-inversion   2026-04-10 17:58:32 +02:00
sayakpaul    a753642a50   fix rest                                         2026-04-10 21:28:20 +05:30
sayakpaul    d4386f4231   fix textual inversion                            2026-04-10 19:53:06 +05:30
11 changed files with 112 additions and 57 deletions

View File

@@ -895,9 +895,8 @@ class TokenEmbeddingsHandler:
self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)
# random initialization of new tokens
- embeds = (
- text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
- )
+ text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
std_token_embedding = embeds.weight.data.std()
logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
@@ -905,9 +904,7 @@ class TokenEmbeddingsHandler:
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
# if initializer_concept are not provided, token embeddings are initialized randomly
if args.initializer_concept is None:
- hidden_size = (
- text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
- )
+ hidden_size = text_module.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
embeds.weight.data[train_ids] = (
torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
* std_token_embedding
@@ -940,7 +937,8 @@ class TokenEmbeddingsHandler:
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
for idx, text_encoder in enumerate(self.text_encoders):
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
- embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
+ text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
new_token_embeddings = embeds.weight.data[train_ids]
@@ -962,7 +960,8 @@ class TokenEmbeddingsHandler:
@torch.no_grad()
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
- embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
+ text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ embeds = text_module.embeddings.token_embedding if idx == 0 else text_encoder.shared
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
embeds.weight.data[index_no_updates] = (
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
@@ -2112,7 +2111,8 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
- unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ _te_one = unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
text_encoder_one.train()
if args.enable_t5_ti:
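Every hunk in this compare applies the same guard; a minimal sketch of the pattern, assuming a transformers-style text encoder (the helper name below is hypothetical, the scripts inline the expression instead):

```python
# Minimal sketch of the guard applied throughout this compare (hypothetical helper name;
# the training scripts inline the expression). CLIP-style encoders wrap the transformer
# in `.text_model`, while other text encoders expose `.embeddings` at the top level.
def _get_text_module(text_encoder):
    return text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder

# Example usage (assumes `text_encoder` is a loaded transformers text model):
# token_embedding = _get_text_module(text_encoder).embeddings.token_embedding
```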

View File

@@ -763,19 +763,28 @@ class TokenEmbeddingsHandler:
self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
# random initialization of new tokens
- std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+ std_token_embedding = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.std()
print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
- text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
- torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[self.train_ids] = (
+ torch.randn(
+ len(self.train_ids),
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).config.hidden_size,
+ )
.to(device=self.device)
.to(dtype=self.dtype)
* std_token_embedding
)
self.embeddings_settings[f"original_embeddings_{idx}"] = (
- text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
- )
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.clone()
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -794,10 +803,14 @@ class TokenEmbeddingsHandler:
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
for idx, text_encoder in enumerate(self.text_encoders):
- assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
- self.tokenizers[0]
- ), "Tokenizers should be the same."
- new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+ assert (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
+ "Tokenizers should be the same."
+ )
+ new_token_embeddings = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[self.train_ids]
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
# text_encoder 1) to keep compatible with the ecosystem.
@@ -819,7 +832,9 @@ class TokenEmbeddingsHandler:
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
- text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_no_updates] = (
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
.to(device=text_encoder.device)
.to(dtype=text_encoder.dtype)
@@ -830,11 +845,15 @@ class TokenEmbeddingsHandler:
std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
index_updates = ~index_no_updates
- new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+ new_embeddings = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_updates]
off_ratio = std_token_embedding / new_embeddings.std()
new_embeddings = new_embeddings * (off_ratio**0.1)
- text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_updates] = new_embeddings
class DreamBoothDataset(Dataset):
@@ -1704,7 +1723,8 @@ def main(args):
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
if args.train_text_encoder:
- text_encoder_one.text_model.embeddings.requires_grad_(True)
+ _te_one = text_encoder_one
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
unet.train()
for step, batch in enumerate(train_dataloader):

View File

@@ -929,19 +929,28 @@ class TokenEmbeddingsHandler:
self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
# random initialization of new tokens
- std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+ std_token_embedding = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.std()
print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
- text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
- torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[self.train_ids] = (
+ torch.randn(
+ len(self.train_ids),
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).config.hidden_size,
+ )
.to(device=self.device)
.to(dtype=self.dtype)
* std_token_embedding
)
self.embeddings_settings[f"original_embeddings_{idx}"] = (
- text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
- )
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.clone()
self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -959,10 +968,14 @@ class TokenEmbeddingsHandler:
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
for idx, text_encoder in enumerate(self.text_encoders):
- assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
- self.tokenizers[0]
- ), "Tokenizers should be the same."
- new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+ assert (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data.shape[0] == len(self.tokenizers[0]), (
+ "Tokenizers should be the same."
+ )
+ new_token_embeddings = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[self.train_ids]
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
# text_encoder 1) to keep compatible with the ecosystem.
@@ -984,7 +997,9 @@ class TokenEmbeddingsHandler:
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
- text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_no_updates] = (
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
.to(device=text_encoder.device)
.to(dtype=text_encoder.dtype)
@@ -995,11 +1010,15 @@ class TokenEmbeddingsHandler:
std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
index_updates = ~index_no_updates
- new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+ new_embeddings = (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_updates]
off_ratio = std_token_embedding / new_embeddings.std()
new_embeddings = new_embeddings * (off_ratio**0.1)
- text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+ (
+ text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ ).embeddings.token_embedding.weight.data[index_updates] = new_embeddings
class DreamBoothDataset(Dataset):
@@ -2083,8 +2102,10 @@ def main(args):
text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works
if args.train_text_encoder:
- accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
- accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+ _te_one = accelerator.unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+ _te_two = accelerator.unwrap_model(text_encoder_two)
+ (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
if pivoted:

View File

@@ -874,10 +874,11 @@ def main(args):
token_embeds[x] = token_embeds[y]
# Freeze all parameters except for the token embeddings in text encoder
+ text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
params_to_freeze = itertools.chain(
- text_encoder.text_model.encoder.parameters(),
- text_encoder.text_model.final_layer_norm.parameters(),
- text_encoder.text_model.embeddings.position_embedding.parameters(),
+ text_module.encoder.parameters(),
+ text_module.final_layer_norm.parameters(),
+ text_module.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)
########################################################

View File

@@ -1691,7 +1691,8 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
- unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ _te_one = unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
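The requires_grad_(True) flip above matters because the re-entrant gradient-checkpointing path only propagates gradients through a segment whose inputs require grad; a minimal sketch of that behaviour, independent of these scripts:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Minimal sketch (not the training script itself): with the re-entrant checkpoint
# implementation, gradients only flow through a checkpointed segment when at least one
# of its inputs requires grad, which is why the hunks above flip requires_grad on the
# text encoder embeddings before training with gradient checkpointing enabled.
block = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)  # a grad-requiring input keeps the graph alive
out = checkpoint(block, x, use_reentrant=True)
out.sum().backward()
assert x.grad is not None and block.weight.grad is not None
```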

View File

@@ -1896,7 +1896,8 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
- unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ _te_one = unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]

View File

@@ -1719,8 +1719,10 @@ def main(args):
text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works
- accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
- accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+ _te_one = accelerator.unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+ _te_two = accelerator.unwrap_model(text_encoder_two)
+ (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]

View File

@@ -1661,8 +1661,10 @@ def main(args):
text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works
- accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
- accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+ _te_one = accelerator.unwrap_model(text_encoder_one)
+ (_te_one.text_model if hasattr(_te_one, "text_model") else _te_one).embeddings.requires_grad_(True)
+ _te_two = accelerator.unwrap_model(text_encoder_two)
+ (_te_two.text_model if hasattr(_te_two, "text_model") else _te_two).embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):

View File

@@ -702,9 +702,10 @@ def main():
vae.requires_grad_(False)
unet.requires_grad_(False)
# Freeze all parameters except for the token embeddings in text encoder
- text_encoder.text_model.encoder.requires_grad_(False)
- text_encoder.text_model.final_layer_norm.requires_grad_(False)
- text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+ text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+ text_module.encoder.requires_grad_(False)
+ text_module.final_layer_norm.requires_grad_(False)
+ text_module.embeddings.position_embedding.requires_grad_(False)
if args.gradient_checkpointing:
# Keep unet in train mode if we are using gradient checkpointing to save memory.
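For context, this hunk keeps only the token embedding trainable; a minimal sketch of the freeze pattern, assuming a CLIP text encoder loaded from transformers:

```python
from transformers import CLIPTextModel

# Minimal sketch of the freeze pattern above, assuming a CLIP-style text encoder; only the
# token embedding table stays trainable, so just the newly added token vectors are updated.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
text_module.encoder.requires_grad_(False)
text_module.final_layer_norm.requires_grad_(False)
text_module.embeddings.position_embedding.requires_grad_(False)
assert text_module.embeddings.token_embedding.weight.requires_grad
```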

View File

@@ -717,12 +717,14 @@ def main():
unet.requires_grad_(False)
# Freeze all parameters except for the token embeddings in text encoder
- text_encoder_1.text_model.encoder.requires_grad_(False)
- text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
- text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
- text_encoder_2.text_model.encoder.requires_grad_(False)
- text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
- text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
+ text_module_1 = text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+ text_module_1.encoder.requires_grad_(False)
+ text_module_1.final_layer_norm.requires_grad_(False)
+ text_module_1.embeddings.position_embedding.requires_grad_(False)
+ text_module_2 = text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+ text_module_2.encoder.requires_grad_(False)
+ text_module_2.final_layer_norm.requires_grad_(False)
+ text_module_2.embeddings.position_embedding.requires_grad_(False)
if args.gradient_checkpointing:
text_encoder_1.gradient_checkpointing_enable()
@@ -767,8 +769,12 @@ def main():
optimizer = optimizer_class(
# only optimize the embeddings
[
- text_encoder_1.text_model.embeddings.token_embedding.weight,
- text_encoder_2.text_model.embeddings.token_embedding.weight,
+ (
+ text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+ ).embeddings.token_embedding.weight,
+ (
+ text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+ ).embeddings.token_embedding.weight,
],
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),

View File

@@ -540,7 +540,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
- if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
+ if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_version(">=", "0.12.3"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
)
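The guarded call above gates a backend on the installed kernels version; a hedged sketch of this kind of minimum-version check using importlib.metadata and packaging (not the diffusers helper itself):

```python
from importlib.metadata import PackageNotFoundError, version
from packaging.version import parse

# Hedged sketch of a minimum-version gate for an optional dependency; this is not the
# diffusers helper, just the general shape of the check performed above.
def _kernels_at_least(minimum: str) -> bool:
    try:
        return parse(version("kernels")) >= parse(minimum)
    except PackageNotFoundError:
        return False

# e.g. raise RuntimeError("please `pip install -U kernels`") when not _kernels_at_least("0.12.3")
```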