mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-07 04:54:47 +08:00

Compare commits: dynamic-te ... add-peft-t
14 Commits
| SHA1 |
|---|
| 0f9427f0dd |
| f6844d3cf6 |
| daa75665cf |
| 38aece94c4 |
| b03aa10375 |
| 2bfdcabadc |
| a837033105 |
| 566aaab423 |
| 9b910bdc5c |
| 965b40aa17 |
| d62076ac5f |
| 3338ce0d40 |
| 565416c2e7 |
| 6338ad5b0b |
```diff
@@ -37,6 +37,8 @@ from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from safetensors.torch import save_file
```
```diff
@@ -54,10 +56,9 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr, unet_lora_state_dict
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
```
```diff
@@ -67,39 +68,6 @@ check_min_version("0.25.0.dev0")
 logger = get_logger(__name__)
 
 
-# TODO: This function should be removed once training scripts are rewritten in PEFT
-def text_encoder_lora_state_dict(text_encoder):
-    state_dict = {}
-
-    def text_encoder_attn_modules(text_encoder):
-        from transformers import CLIPTextModel, CLIPTextModelWithProjection
-
-        attn_modules = []
-
-        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
-            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
-                name = f"text_model.encoder.layers.{i}.self_attn"
-                mod = layer.self_attn
-                attn_modules.append((name, mod))
-
-        return attn_modules
-
-    for name, module in text_encoder_attn_modules(text_encoder):
-        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
-
-    return state_dict
-
-
 def save_model_card(
     repo_id: str,
     images=None,
```
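The helper deleted above walked CLIP's attention layers by hand to collect each `lora_linear_layer`'s weights. Under PEFT the same state dict comes from `peft.utils.get_peft_model_state_dict`, which is what the later hunks switch to. A minimal sketch of that replacement, using a tiny randomly initialized `CLIPTextModel` (the config values here are illustrative test sizes, not the script's):

```python
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from transformers import CLIPTextConfig, CLIPTextModel

# Tiny text encoder so the sketch runs offline, without downloading weights.
text_encoder = CLIPTextModel(
    CLIPTextConfig(
        vocab_size=1000,
        hidden_size=32,
        intermediate_size=64,
        num_attention_heads=4,
        num_hidden_layers=2,
    )
)

# Inject LoRA into the same projections the deleted helper iterated over.
text_encoder.add_adapter(
    LoraConfig(r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"])
)

# One call returns all adapter weights, keyed by module path.
lora_state_dict = get_peft_model_state_dict(text_encoder)
print(len(lora_state_dict), sorted(lora_state_dict)[0])
```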
```diff
@@ -161,8 +129,6 @@ tags:
 base_model: {base_model}
 instance_prompt: {instance_prompt}
 license: openrail++
-widget:
-    - text: '{validation_prompt if validation_prompt else instance_prompt}'
 ---
 """
 
```
```diff
@@ -1264,54 +1230,25 @@ def main(args):
         text_encoder_two.gradient_checkpointing_enable()
 
     # now we will add new LoRA weights to the attention layers
-    # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+    )
+    unet.add_adapter(unet_lora_config)
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_one, dtype=torch.float32, rank=args.rank
-        )
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_two, dtype=torch.float32, rank=args.rank
-        )
+        text_lora_config = LoraConfig(
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+        )
+        text_encoder_one.add_adapter(text_lora_config)
+        text_encoder_two.add_adapter(text_lora_config)
 
     # if we use textual inversion, we freeze all parameters except for the token embeddings
     # in text encoder
```
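The net effect of this hunk: roughly forty lines of manual `LoRALinearLayer` wiring per attention matrix collapse into one `LoraConfig` plus an `add_adapter` call per model. A runnable sketch of the new pattern on a toy-sized, randomly initialized `UNet2DConditionModel` (the constructor arguments below are illustrative test values, not SDXL's):

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Toy UNet so the sketch runs quickly and offline.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
unet.requires_grad_(False)  # the script freezes the base weights earlier on

# One config targets the same to_q/to_k/to_v/to_out.0 projections the old
# loop patched by hand; add_adapter injects LoRA layers into all of them.
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)

# Only the injected LoRA tensors are left trainable, so the explicit
# unet_lora_parameters bookkeeping from the old code is no longer needed.
trainable = [name for name, p in unet.named_parameters() if p.requires_grad]
print(len(trainable), trainable[0])
```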
```diff
@@ -1335,6 +1272,17 @@ def main(args):
             else:
                 param.requires_grad = False
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
```
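The new upcast block exists because fp16 weights cannot absorb typical LoRA-sized optimizer steps: around 1.0 the fp16 grid spacing is about 1e-3, so smaller updates round away to nothing. Upcasting only the trainable adapter tensors fixes this while the frozen base model stays in half precision. A tiny illustration of the failure mode:

```python
import torch

step = 1e-4  # a typical small optimizer update

print(torch.tensor([1.0], dtype=torch.float16) + step)  # rounds back to 1.0; the update is lost
print(torch.tensor([1.0], dtype=torch.float32) + step)  # tensor([1.0001]); the update survives
```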
```diff
@@ -1346,11 +1294,15 @@ def main(args):
 
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_lora_state_dict(model)
+                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
```
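`get_peft_model_state_dict` returns keys in PEFT's `lora_A`/`lora_B` naming, while diffusers serializes UNet LoRA weights under `lora.down`/`lora.up`; `convert_state_dict_to_diffusers` performs that renaming, which is why the save hooks wrap one call in the other. A toy sketch (the key prefix and shapes are made up for illustration):

```python
import torch
from diffusers.utils import convert_state_dict_to_diffusers

# Two fake adapter tensors in PEFT naming for a single to_q projection.
peft_style = {
    "mid_block.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight": torch.zeros(4, 32),
    "mid_block.attentions.0.transformer_blocks.0.attn1.to_q.lora_B.weight": torch.zeros(32, 4),
}

for key in convert_state_dict_to_diffusers(peft_style):
    print(key)  # ...to_q.lora.down.weight / ...to_q.lora.up.weight
```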
```diff
@@ -1407,6 +1359,12 @@ def main(args):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+    if args.train_text_encoder:
+        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
     # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
     freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
 
```
```diff
@@ -1997,13 +1955,17 @@ def main(args):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = unet_lora_state_dict(unet)
+        unet_lora_layers = get_peft_model_state_dict(unet)
 
         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+            text_encoder_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            )
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-            text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            )
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
```
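For the end of the pipeline: the converted state dicts are written out in the diffusers LoRA layout (in this family of SDXL DreamBooth scripts, via `StableDiffusionXLPipeline.save_lora_weights`), so the trained adapter loads back through the standard API. A hedged sketch; the output path and prompt are placeholders:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

pipe.load_lora_weights("path/to/output_dir")  # placeholder: the script's --output_dir
image = pipe("a photo of sks dog", num_inference_steps=25).images[0]
```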