Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 12:34:13 +08:00)
Fix examples not loading LoRA adapter weights from checkpoint (#12690)
* Fix examples not loading LoRA adapter weights from checkpoint

* Updated LoRA saving logic with accelerate save_model_hook and load_model_hook

* Formatted the changes using ruff

* Import and upcasting changed

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
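The fix moves LoRA (de)serialization onto Accelerate's checkpointing hooks, so that `accelerator.save_state()` writes the adapter weights alongside the rest of the training state and `accelerator.load_state()` restores them when resuming. A minimal, stand-alone sketch of that hook mechanism (the toy Linear model and the "checkpoint-demo" path are placeholders, not part of the training script):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))


def save_model_hook(models, weights, output_dir):
    # Called by accelerator.save_state() with the prepared models and their
    # state dicts; popping from `weights` tells Accelerate that the hook has
    # handled serialization itself (the script below saves the LoRA layers here).
    print(f"saving {len(models)} model(s) to {output_dir}")


def load_model_hook(models, input_dir):
    # Called by accelerator.load_state() before the default weight restore.
    print(f"loading {len(models)} model(s) from {input_dir}")


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

accelerator.save_state("checkpoint-demo")  # triggers save_model_hook
accelerator.load_state("checkpoint-demo")  # triggers load_model_hook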
@@ -37,7 +37,7 @@ from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -46,7 +46,12 @@ import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -708,6 +713,56 @@ def main():
         num_workers=args.dataloader_num_workers,
     )
 
+    def save_model_hook(models, weights, output_dir):
+        if accelerator.is_main_process:
+            unet_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(model, type(unwrap_model(unet))):
+                    unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                else:
+                    raise ValueError(f"Unexpected save model: {model.__class__}")
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+            StableDiffusionPipeline.save_lora_weights(
+                save_directory=output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                safe_serialization=True,
+            )
+
+    def load_model_hook(models, input_dir):
+        unet_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+            if isinstance(model, type(unwrap_model(unet))):
+                unet_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        # returns a tuple of state dictionary and network alphas
+        lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+        unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            # throw warning if some unexpected keys are found and continue loading
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        # Make sure the trainable params are in float32
+        if args.mixed_precision in ["fp16"]:
+            cast_training_params([unet_], dtype=torch.float32)
+
     # Scheduler and math around the number of training steps.
     # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
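On the load side, the hook round-trips the checkpoint through formats: `StableDiffusionPipeline.lora_state_dict(input_dir)` returns diffusers-style keys prefixed with the component name, the dict comprehension keeps only the `unet.`-prefixed entries and strips the prefix, and `convert_unet_state_dict_to_peft` plus `set_peft_model_state_dict` push the result back into the prepared PEFT model. A small illustration of the prefix-filtering step, with made-up key names:

# Made-up keys purely for illustration; real checkpoints have one entry per LoRA matrix.
lora_state_dict = {
    "unet.down_blocks.0.attentions.0.proj_in.lora.down.weight": "unet tensor",
    "text_encoder.text_model.encoder.layers.0.q_proj.lora.down.weight": "text encoder tensor",
}
unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
print(list(unet_state_dict))  # ['down_blocks.0.attentions.0.proj_in.lora.down.weight']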
@@ -732,6 +787,10 @@ def main():
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
+    # Register the hooks for efficient saving and loading of LoRA weights
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
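With both hooks registered on the same `accelerator` that prepared the UNet, the existing `accelerator.save_state(save_path)` call in the checkpointing branch also writes the LoRA file (`pytorch_lora_weights.safetensors`) into each checkpoint directory, and resuming through the script's `--resume_from_checkpoint` path, which ends in `accelerator.load_state`, re-applies the adapter weights. That load step is the part that was previously missing.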
@@ -906,17 +965,6 @@ def main():
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
 
-                        unwrapped_unet = unwrap_model(unet)
-                        unet_lora_state_dict = convert_state_dict_to_diffusers(
-                            get_peft_model_state_dict(unwrapped_unet)
-                        )
-
-                        StableDiffusionPipeline.save_lora_weights(
-                            save_directory=save_path,
-                            unet_lora_layers=unet_lora_state_dict,
-                            safe_serialization=True,
-                        )
-
                         logger.info(f"Saved state to {save_path}")
 
             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
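Since save_model_hook already writes the adapter during `accelerator.save_state`, the manual `save_lora_weights` block above becomes redundant and is removed. The per-checkpoint LoRA file can still be loaded for inference in the usual way; a sketch with a placeholder model ID and checkpoint path:

import torch
from diffusers import StableDiffusionPipeline

# Placeholder model ID and checkpoint directory; substitute your own run's paths.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("sd-model-finetuned-lora/checkpoint-500")  # reads pytorch_lora_weights.safetensors
image = pipe("a photo of an astronaut riding a horse").images[0]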