Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-11 23:14:37 +08:00)

Compare commits: custom-cod...fix/lora-l (32 commits)
| SHA1 |
|---|
| f9c7b327cb |
| 255adf0466 |
| 5d04eebd1f |
| f145d48ed7 |
| 765fef7134 |
| 8c98a187c7 |
| ece6d89cf2 |
| 16ac1b2f4f |
| ec9df6fc48 |
| 09618d09a6 |
| d24e7d3ea9 |
| 57a16f35ee |
| f4adaae5cb |
| 49a0f3ab02 |
| 24cb282d36 |
| bcf0f4a789 |
| fdb114618d |
| ed333f06ae |
| 32212b6df6 |
| 9ecb271ac8 |
| a2792cd942 |
| c341111d69 |
| b868e8a2fc |
| 41b9cd8787 |
| 0d08249c9f |
| 3b27b23082 |
| e4c00bc5c2 |
| cf132fb6b0 |
| 981ea82591 |
| 20fac7bc9d |
| 79b16373b7 |
| ff3d380824 |
@@ -880,11 +880,16 @@ def main(args):
             unet_lora_layers_to_save = None
             text_encoder_lora_layers_to_save = None
 
+            unet_lora_config = None
+            text_encoder_lora_config = None
+
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
                     unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                    unet_lora_config = model.peft_config["default"]
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
                     text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_lora_config = model.peft_config["default"]
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 

@@ -895,6 +900,8 @@ def main(args):
                 output_dir,
                 unet_lora_layers=unet_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+                unet_lora_config=unet_lora_config,
+                text_encoder_lora_config=text_encoder_lora_config,
             )
 
     def load_model_hook(models, input_dir):

@@ -911,10 +918,12 @@ def main(args):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(
+            lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
+        )
         LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
+            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_, config=metadata
         )
 
     accelerator.register_save_state_pre_hook(save_model_hook)

@@ -1315,17 +1324,22 @@ def main(args):
         unet = unet.to(torch.float32)
 
         unet_lora_state_dict = get_peft_model_state_dict(unet)
+        unet_lora_config = unet.peft_config["default"]
 
         if args.train_text_encoder:
             text_encoder = accelerator.unwrap_model(text_encoder)
             text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
+            text_encoder_lora_config = text_encoder.peft_config["default"]
         else:
             text_encoder_state_dict = None
+            text_encoder_lora_config = None
 
         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_state_dict,
             text_encoder_lora_layers=text_encoder_state_dict,
+            unet_lora_config=unet_lora_config,
+            text_encoder_lora_config=text_encoder_lora_config,
         )
 
         # Final inference
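Across the training scripts touched by this comparison the pattern is the same: wherever a PEFT-wrapped model's LoRA state dict is collected, its `peft_config["default"]` is collected next to it and forwarded through the new `*_lora_config` keyword arguments of `save_lora_weights`. A minimal sketch of that save path (the helper name and its arguments are illustrative, and the keyword arguments only exist on this branch):

```python
# Illustrative sketch, not the script's literal code.
from diffusers.loaders import LoraLoaderMixin
from peft.utils import get_peft_model_state_dict


def save_lora_with_config(unet, text_encoder, output_dir):
    # Trained LoRA weights to serialize.
    unet_lora_layers = get_peft_model_state_dict(unet)
    # The LoraConfig (rank, lora_alpha, target modules, ...) that produced them.
    unet_lora_config = unet.peft_config["default"]

    text_encoder_lora_layers = None
    text_encoder_lora_config = None
    if text_encoder is not None:
        text_encoder_lora_layers = get_peft_model_state_dict(text_encoder)
        text_encoder_lora_config = text_encoder.peft_config["default"]

    # On this branch, save_lora_weights also serializes the configs as file metadata.
    LoraLoaderMixin.save_lora_weights(
        output_dir,
        unet_lora_layers=unet_lora_layers,
        text_encoder_lora_layers=text_encoder_lora_layers,
        unet_lora_config=unet_lora_config,
        text_encoder_lora_config=text_encoder_lora_config,
    )
```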
@@ -1033,13 +1033,20 @@ def main(args):
             text_encoder_one_lora_layers_to_save = None
             text_encoder_two_lora_layers_to_save = None
 
+            unet_lora_config = None
+            text_encoder_one_lora_config = None
+            text_encoder_two_lora_config = None
+
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
                     unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                    unet_lora_config = model.peft_config["default"]
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_one_lora_config = model.peft_config["default"]
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
                     text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_two_lora_config = model.peft_config["default"]
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 

@@ -1051,6 +1058,9 @@ def main(args):
                 unet_lora_layers=unet_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                 text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+                unet_lora_config=unet_lora_config,
+                text_encoder_lora_config=text_encoder_one_lora_config,
+                text_encoder_2_lora_config=text_encoder_two_lora_config,
             )
 
     def load_model_hook(models, input_dir):

@@ -1070,17 +1080,19 @@ def main(args):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(
+            lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
+        )
 
         text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
        )
 
         text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
         )
 
     accelerator.register_save_state_pre_hook(save_model_hook)

@@ -1616,21 +1628,29 @@ def main(args):
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
         unet_lora_layers = get_peft_model_state_dict(unet)
+        unet_lora_config = unet.peft_config["default"]
 
         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
             text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
+            text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
+            text_encoder_one_lora_config = None
+            text_encoder_two_lora_config = None
 
         StableDiffusionXLPipeline.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
             text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+            unet_lora_config=unet_lora_config,
+            text_encoder_lora_config=text_encoder_one_lora_config,
+            text_encoder_2_lora_config=text_encoder_two_lora_config,
         )
 
         # Final inference
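The SDXL script applies the same bookkeeping to three adapters, so the call grows three config arguments. A hedged sketch of the resulting call shape (variable names are placeholders for the script's own objects):

```python
# Illustrative only; assumes this branch's save_lora_weights signature.
from diffusers import StableDiffusionXLPipeline
from peft.utils import get_peft_model_state_dict


def save_sdxl_lora_with_configs(unet, text_encoder_one, text_encoder_two, output_dir):
    StableDiffusionXLPipeline.save_lora_weights(
        save_directory=output_dir,
        unet_lora_layers=get_peft_model_state_dict(unet),
        text_encoder_lora_layers=get_peft_model_state_dict(text_encoder_one),
        text_encoder_2_lora_layers=get_peft_model_state_dict(text_encoder_two),
        unet_lora_config=unet.peft_config["default"],
        text_encoder_lora_config=text_encoder_one.peft_config["default"],
        text_encoder_2_lora_config=text_encoder_two.peft_config["default"],
    )
```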
@@ -833,10 +833,12 @@ def main():
                         accelerator.save_state(save_path)
 
                         unet_lora_state_dict = get_peft_model_state_dict(unet)
+                        unet_lora_config = unet.peft_config["default"]
 
                         StableDiffusionPipeline.save_lora_weights(
                             save_directory=save_path,
                             unet_lora_layers=unet_lora_state_dict,
+                            unet_lora_config=unet_lora_config,
                             safe_serialization=True,
                         )
 

@@ -898,10 +900,12 @@ def main():
         unet = unet.to(torch.float32)
 
         unet_lora_state_dict = get_peft_model_state_dict(unet)
+        unet_lora_config = unet.peft_config["default"]
         StableDiffusionPipeline.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_state_dict,
             safe_serialization=True,
+            unet_lora_config=unet_lora_config,
         )
 
         if args.push_to_hub:
@@ -682,13 +682,20 @@ def main(args):
             text_encoder_one_lora_layers_to_save = None
             text_encoder_two_lora_layers_to_save = None
 
+            unet_lora_config = None
+            text_encoder_one_lora_config = None
+            text_encoder_two_lora_config = None
+
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
                     unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                    unet_lora_config = model.peft_config["default"]
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_one_lora_config = model.peft_config["default"]
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
                     text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_two_lora_config = model.peft_config["default"]
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 

@@ -700,6 +707,9 @@ def main(args):
                 unet_lora_layers=unet_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                 text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+                unet_lora_config=unet_lora_config,
+                text_encoder_lora_config=text_encoder_one_lora_config,
+                text_encoder_2_lora_config=text_encoder_two_lora_config,
             )
 
     def load_model_hook(models, input_dir):

@@ -719,17 +729,19 @@ def main(args):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(
+            lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
+        )
 
         text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
         )
 
         text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
         )
 
     accelerator.register_save_state_pre_hook(save_model_hook)

@@ -1194,6 +1206,7 @@ def main(args):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet_lora_state_dict = get_peft_model_state_dict(unet)
+        unet_lora_config = unet.peft_config["default"]
 
         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)

@@ -1201,15 +1214,23 @@ def main(args):
 
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
             text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
+
+            text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
+            text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
+            text_encoder_one_lora_config = None
+            text_encoder_two_lora_config = None
 
         StableDiffusionXLPipeline.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_state_dict,
             text_encoder_lora_layers=text_encoder_lora_layers,
             text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+            unet_lora_config=unet_lora_config,
+            text_encoder_lora_config=text_encoder_one_lora_config,
+            text_encoder_2_lora_config=text_encoder_two_lora_config,
         )
 
         del unet
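On the resume side, every `load_model_hook` above now unpacks a third return value from `lora_state_dict` and threads it through as `config=`, so the rebuilt `LoraConfig` matches the one used during training rather than being re-derived from heuristics. A compact sketch of that load path (the helper name is illustrative; the calls mirror the hooks):

```python
# Illustrative sketch assuming this branch, where lora_state_dict returns
# (state_dict, network_alphas, metadata).
from diffusers.loaders import LoraLoaderMixin


def load_lora_from_checkpoint(input_dir, unet, text_encoder):
    lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
    LoraLoaderMixin.load_lora_into_unet(
        lora_state_dict, network_alphas=network_alphas, unet=unet, config=metadata
    )
    LoraLoaderMixin.load_lora_into_text_encoder(
        lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder, config=metadata
    )
```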
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import os
 from contextlib import nullcontext
 from typing import Callable, Dict, List, Optional, Union

@@ -103,7 +104,7 @@ class LoraLoaderMixin:
                 `default_{i}` where i is the total number of adapters being loaded.
         """
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:

@@ -114,6 +115,7 @@ class LoraLoaderMixin:
         self.load_lora_into_unet(
             state_dict,
             network_alphas=network_alphas,
+            config=metadata,
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
             low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,

@@ -125,6 +127,7 @@ class LoraLoaderMixin:
             text_encoder=getattr(self, self.text_encoder_name)
             if not hasattr(self, "text_encoder")
             else self.text_encoder,
+            config=metadata,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,

@@ -219,6 +222,7 @@ class LoraLoaderMixin:
         }
 
         model_file = None
+        metadata = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
             if (use_safetensors and weight_name is None) or (

@@ -248,6 +252,8 @@ class LoraLoaderMixin:
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
+                    metadata = f.metadata()
             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
                     raise e

@@ -294,7 +300,7 @@ class LoraLoaderMixin:
                 state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
 
-        return state_dict, network_alphas
+        return state_dict, network_alphas, metadata
 
     @classmethod
     def _best_guess_weight_name(

@@ -370,7 +376,7 @@ class LoraLoaderMixin:
 
     @classmethod
     def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
+        cls, state_dict, network_alphas, unet, config=None, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -384,6 +390,8 @@ class LoraLoaderMixin:
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
+            config: (`Dict`):
+                LoRA configuration parsed from the state dict.
             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                 Speed up model loading only loading the pretrained weights and not initializing the weights. This also
                 tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.

@@ -443,7 +451,9 @@ class LoraLoaderMixin:
                 if "lora_B" in key:
                     rank[key] = val.shape[1]
 
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            if config is not None and isinstance(config, dict) and len(config) > 0:
+                config = json.loads(config["unet"])
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True)
             lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name

@@ -484,6 +494,7 @@ class LoraLoaderMixin:
         network_alphas,
         text_encoder,
         prefix=None,
+        config=None,
         lora_scale=1.0,
         low_cpu_mem_usage=None,
         adapter_name=None,

@@ -502,6 +513,8 @@ class LoraLoaderMixin:
                 The text encoder model to load the LoRA layers into.
             prefix (`str`):
                 Expected prefix of the `text_encoder` in the `state_dict`.
+            config (`Dict`):
+                LoRA configuration parsed from state dict.
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.

@@ -575,10 +588,11 @@ class LoraLoaderMixin:
             if USE_PEFT_BACKEND:
                 from peft import LoraConfig
 
+                if config is not None and len(config) > 0:
+                    config = json.loads(config[prefix])
                 lora_config_kwargs = get_peft_kwargs(
-                    rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
+                    rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
                 )
-
                 lora_config = LoraConfig(**lora_config_kwargs)
 
                 # adapter_name

@@ -786,6 +800,8 @@ class LoraLoaderMixin:
         save_directory: Union[str, os.PathLike],
         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+        unet_lora_config=None,
+        text_encoder_lora_config=None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,

@@ -813,21 +829,54 @@ class LoraLoaderMixin:
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
+        if not USE_PEFT_BACKEND and not safe_serialization:
+            if unet_lora_config or text_encoder_lora_config:
+                raise ValueError(
+                    "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`."
+                )
+        elif USE_PEFT_BACKEND and safe_serialization:
+            from peft import LoraConfig
+
+        if not (unet_lora_layers or text_encoder_lora_layers):
+            raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
+
         state_dict = {}
+        metadata = {}
 
         def pack_weights(layers, prefix):
             layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
             layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+
             return layers_state_dict
 
-        if not (unet_lora_layers or text_encoder_lora_layers):
-            raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
+        def pack_metadata(config, prefix):
+            local_metadata = {}
+            if config is not None:
+                if isinstance(config, LoraConfig):
+                    config = config.to_dict()
+                for key, value in config.items():
+                    if isinstance(value, set):
+                        config[key] = list(value)
+
+                config_as_string = json.dumps(config, indent=2, sort_keys=True)
+                local_metadata[prefix] = config_as_string
+            return local_metadata
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+            prefix = "unet"
+            unet_state_dict = pack_weights(unet_lora_layers, prefix)
+            state_dict.update(unet_state_dict)
+            if unet_lora_config is not None:
+                unet_metadata = pack_metadata(unet_lora_config, prefix)
+                metadata.update(unet_metadata)
 
         if text_encoder_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            prefix = "text_encoder"
+            text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
+            state_dict.update(text_encoder_state_dict)
+            if text_encoder_lora_config is not None:
+                text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
+                metadata.update(text_encoder_metadata)
 
         # Save the model
         cls.write_lora_layers(

@@ -837,6 +886,7 @@ class LoraLoaderMixin:
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            metadata=metadata,
         )
 
     @staticmethod

@@ -847,7 +897,11 @@ class LoraLoaderMixin:
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        metadata=None,
     ):
+        if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0:
+            raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.")
+
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

@@ -855,8 +909,10 @@ class LoraLoaderMixin:
         if save_function is None:
             if safe_serialization:
 
-                def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                def save_function(weights, filename, metadata):
+                    if metadata is None:
+                        metadata = {"format": "pt"}
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)
 
             else:
                 save_function = torch.save

@@ -869,7 +925,10 @@ class LoraLoaderMixin:
         else:
             weight_name = LORA_WEIGHT_NAME
 
-        save_function(state_dict, os.path.join(save_directory, weight_name))
+        if save_function != torch.save:
+            save_function(state_dict, os.path.join(save_directory, weight_name), metadata)
+        else:
+            save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
     def unload_lora_weights(self):

@@ -1301,7 +1360,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         # pipeline.
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(
+        state_dict, network_alphas, metadata = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
             **kwargs,

@@ -1311,7 +1370,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
             raise ValueError("Invalid LoRA checkpoint.")
 
         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            config=metadata,
+            adapter_name=adapter_name,
+            _pipeline=self,
         )
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:

@@ -1319,6 +1383,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
                 text_encoder_state_dict,
                 network_alphas=network_alphas,
                 text_encoder=self.text_encoder,
+                config=metadata,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,

@@ -1331,6 +1396,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
                 text_encoder_2_state_dict,
                 network_alphas=network_alphas,
                 text_encoder=self.text_encoder_2,
+                config=metadata,
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,

@@ -1344,6 +1410,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        unet_lora_config=None,
+        text_encoder_lora_config=None,
+        text_encoder_2_lora_config=None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,

@@ -1371,24 +1440,63 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
-        state_dict = {}
-
-        def pack_weights(layers, prefix):
-            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
-            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
-            return layers_state_dict
+        if not USE_PEFT_BACKEND and not safe_serialization:
+            if unet_lora_config or text_encoder_lora_config or text_encoder_2_lora_config:
+                raise ValueError(
+                    "Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`."
+                )
+        elif USE_PEFT_BACKEND and safe_serialization:
+            from peft import LoraConfig
 
         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
                 "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
             )
 
+        state_dict = {}
+        metadata = {}
+
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+
+            return layers_state_dict
+
+        def pack_metadata(config, prefix):
+            local_metadata = {}
+            if config is not None:
+                if isinstance(config, LoraConfig):
+                    config = config.to_dict()
+                for key, value in config.items():
+                    if isinstance(value, set):
+                        config[key] = list(value)
+
+                config_as_string = json.dumps(config, indent=2, sort_keys=True)
+                local_metadata[prefix] = config_as_string
+            return local_metadata
+
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+            prefix = "unet"
+            unet_state_dict = pack_weights(unet_lora_layers, prefix)
+            state_dict.update(unet_state_dict)
+            if unet_lora_config is not None:
+                unet_metadata = pack_metadata(unet_lora_config, prefix)
+                metadata.update(unet_metadata)
 
         if text_encoder_lora_layers and text_encoder_2_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
-            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+            prefix = "text_encoder"
+            text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
+            state_dict.update(text_encoder_state_dict)
+            if text_encoder_lora_config is not None:
+                text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
+                metadata.update(text_encoder_metadata)
+
+            prefix = "text_encoder_2"
+            text_encoder_2_state_dict = pack_weights(text_encoder_2_lora_layers, prefix)
+            state_dict.update(text_encoder_2_state_dict)
+            if text_encoder_2_lora_config is not None:
+                text_encoder_2_metadata = pack_metadata(text_encoder_2_lora_config, prefix)
+                metadata.update(text_encoder_2_metadata)
 
         cls.write_lora_layers(
             state_dict=state_dict,

@@ -1397,6 +1505,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            metadata=metadata,
         )
 
     def _remove_text_encoder_monkey_patch(self):
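Mechanically, the loader changes store each component's `LoraConfig` as a JSON string in the safetensors header: `pack_metadata` serializes it under the component prefix, `write_lora_layers` passes the resulting dict to `safetensors.torch.save_file(..., metadata=...)`, and `lora_state_dict` reads it back with `safetensors.safe_open(...).metadata()`. A standalone sketch of that round trip (the tensor, config values, and file name below are made up for illustration):

```python
# Illustrative round trip of the metadata format; not library code.
import json

import safetensors.torch
import torch

weights = {"unet.some_layer.lora_A.weight": torch.zeros(4, 8)}  # stand-in tensor
unet_config = {"r": 4, "lora_alpha": 8, "target_modules": ["to_q", "to_k"]}

# Safetensors metadata must be str -> str, hence the JSON encoding per prefix.
metadata = {"unet": json.dumps(unet_config, indent=2, sort_keys=True)}
safetensors.torch.save_file(weights, "pytorch_lora_weights.safetensors", metadata=metadata)

# Reading it back mirrors what lora_state_dict now does.
with safetensors.safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    stored = f.metadata()
restored_config = json.loads(stored["unet"])
assert restored_config["lora_alpha"] == 8
```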
@@ -138,11 +138,17 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
         module.set_scale(adapter_name, 1.0)
 
 
-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True):
     rank_pattern = {}
     alpha_pattern = {}
     r = lora_alpha = list(rank_dict.values())[0]
 
+    # Try to retrive config.
+    alpha_retrieved = False
+    if config is not None:
+        lora_alpha = config["lora_alpha"]
+        alpha_retrieved = True
+
     if len(set(rank_dict.values())) > 1:
         # get the rank occuring the most number of times
         r = collections.Counter(rank_dict.values()).most_common()[0][0]

@@ -154,7 +160,8 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
     if network_alpha_dict is not None and len(network_alpha_dict) > 0:
         if len(set(network_alpha_dict.values())) > 1:
             # get the alpha occuring the most number of times
-            lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
+            if not alpha_retrieved:
+                lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
 
             # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
             alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))

@@ -165,7 +172,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
                 }
             else:
                 alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
-        else:
+        elif not alpha_retrieved:
             lora_alpha = set(network_alpha_dict.values()).pop()
 
     # layer names without the Diffusers specific
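The `get_peft_kwargs` change amounts to a precedence rule for `lora_alpha`: a value recovered from the stored config wins, and only if none was retrieved do the existing `network_alphas` heuristics apply. A hedged distillation of that rule (the function name and `default_alpha` argument are illustrative, not part of the patch):

```python
# Illustrative restatement of the precedence logic, not the library function.
import collections


def resolve_lora_alpha(config, network_alpha_dict, default_alpha):
    if config is not None:
        # Retrieved from the serialized LoraConfig: takes priority.
        return config["lora_alpha"]
    if network_alpha_dict:
        values = set(network_alpha_dict.values())
        if len(values) > 1:
            # Fall back to the most common alpha, as the existing heuristic does.
            return collections.Counter(network_alpha_dict.values()).most_common()[0][0]
        return values.pop()
    return default_alpha
```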
@@ -107,8 +107,9 @@ class PeftLoraLoaderMixinTests:
     unet_kwargs = None
     vae_kwargs = None
 
-    def get_dummy_components(self, scheduler_cls=None):
+    def get_dummy_components(self, scheduler_cls=None, lora_alpha=None):
         scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
+        lora_alpha = 4 if lora_alpha is None else lora_alpha
 
         torch.manual_seed(0)
         unet = UNet2DConditionModel(**self.unet_kwargs)

@@ -123,11 +124,14 @@ class PeftLoraLoaderMixinTests:
         tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
 
         text_lora_config = LoraConfig(
-            r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
+            r=4,
+            lora_alpha=lora_alpha,
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+            init_lora_weights=False,
         )
 
         unet_lora_config = LoraConfig(
-            r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
+            r=4, lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
         )
 
         unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)

@@ -714,6 +718,68 @@ class PeftLoraLoaderMixinTests:
             "Fused lora should change the output",
         )
 
+    def test_if_lora_alpha_is_correctly_parsed(self):
+        lora_alpha = 8
+
+        components, _, text_lora_config, unet_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(self.torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe.unet.add_adapter(unet_lora_config)
+        pipe.text_encoder.add_adapter(text_lora_config)
+        if self.has_two_text_encoders:
+            pipe.text_encoder_2.add_adapter(text_lora_config)
+
+        # Inference works?
+        _ = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            unet_state_dict = get_peft_model_state_dict(pipe.unet)
+            unet_lora_config = pipe.unet.peft_config["default"]
+
+            text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
+            text_encoder_lora_config = pipe.text_encoder.peft_config["default"]
+
+            if self.has_two_text_encoders:
+                text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+                text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"]
+
+                self.pipeline_class.save_lora_weights(
+                    save_directory=tmpdirname,
+                    unet_lora_layers=unet_state_dict,
+                    text_encoder_lora_layers=text_encoder_state_dict,
+                    text_encoder_2_lora_layers=text_encoder_2_state_dict,
+                    unet_lora_config=unet_lora_config,
+                    text_encoder_lora_config=text_encoder_lora_config,
+                    text_encoder_2_lora_config=text_encoder_2_lora_config,
+                )
+            else:
+                self.pipeline_class.save_lora_weights(
+                    save_directory=tmpdirname,
+                    unet_lora_layers=unet_state_dict,
+                    text_encoder_lora_layers=text_encoder_state_dict,
+                    unet_lora_config=unet_lora_config,
+                    text_encoder_lora_config=text_encoder_lora_config,
+                )
+            loaded_pipe = self.pipeline_class(**components)
+            loaded_pipe.load_lora_weights(tmpdirname)
+
+            # Inference works?
+            _ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            assert (
+                loaded_pipe.unet.peft_config["default"].lora_alpha == lora_alpha
+            ), "LoRA alpha not correctly loaded for UNet."
+            assert (
+                loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha
+            ), "LoRA alpha not correctly loaded for text encoder."
+            if self.has_two_text_encoders:
+                assert (
+                    loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha
+                ), "LoRA alpha not correctly loaded for text encoder 2."
+
     def test_simple_inference_with_text_unet_lora_unfused(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights