Fix examples not loading LoRA adapter weights from checkpoint (#12690)

* Fix examples not loading LoRA adapter weights from checkpoint

* Updated LoRA saving logic with Accelerate save_model_hook and load_model_hook

* Formatted the changes using ruff

* Changed imports and upcasting

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Ayush Sur
Date: 2025-11-28 11:56:39 +05:30
Committed by: GitHub
Parent: 01e355516b
Commit: 1b91856d0e

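Background on the mechanism: the fix registers pre-hooks on the Accelerator so that accelerator.save_state() and accelerator.load_state() serialize and restore the LoRA adapter weights themselves, instead of saving them at each checkpoint without ever loading them back on resume. The following is a minimal, self-contained sketch of that Accelerate hook pattern, not the training script's own code; the torch.nn.Linear model, optimizer, and file names are placeholders.

import os

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)  # placeholder model standing in for the UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)


def save_model_hook(models, weights, output_dir):
    # Called by accelerator.save_state() before it writes its own model files.
    # Popping an entry from `weights` tells Accelerate to skip its default
    # serialization for the corresponding model, so we write it ourselves.
    os.makedirs(output_dir, exist_ok=True)
    for m in models:
        torch.save(accelerator.unwrap_model(m).state_dict(), os.path.join(output_dir, "custom_weights.pt"))
        weights.pop()


def load_model_hook(models, input_dir):
    # Called by accelerator.load_state(); popping a model from the list means
    # this hook is responsible for restoring that model's weights.
    while len(models) > 0:
        m = models.pop()
        accelerator.unwrap_model(m).load_state_dict(torch.load(os.path.join(input_dir, "custom_weights.pt")))


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

accelerator.save_state("checkpoint-demo")  # triggers save_model_hook
accelerator.load_state("checkpoint-demo")  # triggers load_model_hook

In the actual hooks shown in the diff below, the model is the PEFT-wrapped UNet, and the weights go through get_peft_model_state_dict on save and set_peft_model_state_dict (after key conversion) on load.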

@@ -37,7 +37,7 @@ from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -46,7 +46,12 @@ import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -708,6 +713,56 @@ def main():
         num_workers=args.dataloader_num_workers,
     )
 
+    def save_model_hook(models, weights, output_dir):
+        if accelerator.is_main_process:
+            unet_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(model, type(unwrap_model(unet))):
+                    unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                else:
+                    raise ValueError(f"Unexpected save model: {model.__class__}")
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+            StableDiffusionPipeline.save_lora_weights(
+                save_directory=output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                safe_serialization=True,
+            )
+
+    def load_model_hook(models, input_dir):
+        unet_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+
+            if isinstance(model, type(unwrap_model(unet))):
+                unet_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        # returns a tuple of state dictionary and network alphas
+        lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+        unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            # throw warning if some unexpected keys are found and continue loading
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        # Make sure the trainable params are in float32
+        if args.mixed_precision in ["fp16"]:
+            cast_training_params([unet_], dtype=torch.float32)
+
     # Scheduler and math around the number of training steps.
     # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
@@ -732,6 +787,10 @@ def main():
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
+    # Register the hooks for efficient saving and loading of LoRA weights
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
@@ -906,17 +965,6 @@ def main():
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
 
-                        unwrapped_unet = unwrap_model(unet)
-                        unet_lora_state_dict = convert_state_dict_to_diffusers(
-                            get_peft_model_state_dict(unwrapped_unet)
-                        )
-
-                        StableDiffusionPipeline.save_lora_weights(
-                            save_directory=save_path,
-                            unet_lora_layers=unet_lora_state_dict,
-                            safe_serialization=True,
-                        )
-
                         logger.info(f"Saved state to {save_path}")
 
             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}