Compare commits

...

6 Commits

Author SHA1 Message Date
Sayak Paul
7394b99047 Merge branch 'main' into fix-casting-training-params 2024-06-18 14:45:33 +01:00
Sayak Paul
4716a413bf Merge branch 'main' into fix-casting-training-params 2024-06-18 14:18:05 +01:00
Sayak Paul
c843fb25ec Merge branch 'main' into fix-casting-training-params 2024-06-17 20:43:54 +01:00
Sayak Paul
38d768a876 Merge branch 'main' into fix-casting-training-params 2024-06-11 13:09:41 +01:00
Sayak Paul
2c35ea66d9 Merge branch 'main' into fix-casting-training-params 2024-06-10 13:57:27 +01:00
sayakpaul
34fbd5526a fix the position of param casting when loading them 2024-06-10 13:53:59 +01:00
2 changed files with 4 additions and 4 deletions

View File

@@ -1289,8 +1289,8 @@ def main(args):
models = [unet_]
if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_])
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(models)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

View File

@@ -1363,8 +1363,8 @@ def main(args):
models = [unet_]
if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_])
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(models)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)