Compare commits

...

3 Commits

Author SHA1 Message Date
Dhruv Nair
3804034034 update 2024-07-05 05:41:26 +00:00
Dhruv Nair
fe403df178 update 2024-07-02 06:51:43 +00:00
Dhruv Nair
4ae0695699 update 2024-07-02 06:34:28 +00:00
2 changed files with 3 additions and 2 deletions

View File

@@ -1290,6 +1290,7 @@ def main(args):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
@@ -1981,7 +1982,7 @@ def main(args):
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
save_model_card(
model_id if not args.push_to_hub else repo_id,

View File

@@ -2413,7 +2413,7 @@ def main(args):
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
save_model_card(
model_id if not args.push_to_hub else repo_id,