fix textual inversion

sayakpaul
2026-04-10 19:53:06 +05:30
parent 4548e68e80
commit d4386f4231
2 changed files with 18 additions and 11 deletions


@@ -702,9 +702,10 @@ def main():
     vae.requires_grad_(False)
     unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    text_encoder.text_model.encoder.requires_grad_(False)
-    text_encoder.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+    text_module.encoder.requires_grad_(False)
+    text_module.final_layer_norm.requires_grad_(False)
+    text_module.embeddings.position_embedding.requires_grad_(False)
     if args.gradient_checkpointing:
         # Keep unet in train mode if we are using gradient checkpointing to save memory.
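
The guard introduced in this first file (the single-encoder textual inversion script) resolves the inner transformer whether the loaded encoder nests it under a `text_model` submodule, as `transformers.CLIPTextModel` does, or exposes it at the top level. A minimal sketch of the resulting freeze behavior, assuming a CLIP-style text encoder from `transformers` (the checkpoint name is only illustrative):

    from transformers import CLIPTextModel

    # Illustrative checkpoint; any CLIP-style text encoder behaves the same way.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    # Resolve the inner transformer whether or not it is nested under `.text_model`.
    text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder

    # Freeze everything except the token embedding table.
    text_module.encoder.requires_grad_(False)
    text_module.final_layer_norm.requires_grad_(False)
    text_module.embeddings.position_embedding.requires_grad_(False)

    # Only the token embedding should remain trainable.
    print([n for n, p in text_encoder.named_parameters() if p.requires_grad])
    # e.g. ['text_model.embeddings.token_embedding.weight']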


@@ -717,12 +717,14 @@ def main():
     unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    text_encoder_1.text_model.encoder.requires_grad_(False)
-    text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder_2.text_model.encoder.requires_grad_(False)
-    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_module_1 = text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+    text_module_1.encoder.requires_grad_(False)
+    text_module_1.final_layer_norm.requires_grad_(False)
+    text_module_1.embeddings.position_embedding.requires_grad_(False)
+    text_module_2 = text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+    text_module_2.encoder.requires_grad_(False)
+    text_module_2.final_layer_norm.requires_grad_(False)
+    text_module_2.embeddings.position_embedding.requires_grad_(False)
     if args.gradient_checkpointing:
         text_encoder_1.gradient_checkpointing_enable()
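
The second file (the SDXL variant) repeats the same guard for each of its two text encoders. A sketch of how the pattern generalizes, with a hypothetical `freeze_all_but_token_embedding` helper standing in for the repeated inline code:

    def freeze_all_but_token_embedding(*encoders):
        """Freeze every text-encoder parameter except its token embedding table."""
        for enc in encoders:
            module = enc.text_model if hasattr(enc, "text_model") else enc
            module.encoder.requires_grad_(False)
            module.final_layer_norm.requires_grad_(False)
            module.embeddings.position_embedding.requires_grad_(False)

    # Usage with the two SDXL encoders loaded earlier in the script:
    # freeze_all_but_token_embedding(text_encoder_1, text_encoder_2)
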
@@ -767,8 +769,12 @@ def main():
     optimizer = optimizer_class(
         # only optimize the embeddings
         [
-            text_encoder_1.text_model.embeddings.token_embedding.weight,
-            text_encoder_2.text_model.embeddings.token_embedding.weight,
+            (
+                text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+            ).embeddings.token_embedding.weight,
+            (
+                text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+            ).embeddings.token_embedding.weight,
         ],
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
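
The optimizer hunk applies the same guard when collecting the two trainable embedding tensors. A sketch of the equivalent construction, assuming the standard SDXL checkpoint layout, with `torch.optim.AdamW` standing in for the script's `optimizer_class` and the `token_embedding_weight` helper being hypothetical:

    import torch
    from transformers import CLIPTextModel, CLIPTextModelWithProjection

    # Illustrative checkpoint; subfolder names follow the standard SDXL layout.
    repo = "stabilityai/stable-diffusion-xl-base-1.0"
    text_encoder_1 = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2")

    def token_embedding_weight(encoder):
        # Hypothetical helper mirroring the inline guard in the diff.
        module = encoder.text_model if hasattr(encoder, "text_model") else encoder
        return module.embeddings.token_embedding.weight

    # Only the two token embedding tables are handed to the optimizer.
    optimizer = torch.optim.AdamW(
        [token_embedding_weight(text_encoder_1), token_embedding_weight(text_encoder_2)],
        lr=1e-4,             # placeholder for args.learning_rate
        betas=(0.9, 0.999),  # placeholder for (args.adam_beta1, args.adam_beta2)
    )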