Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-18 18:34:37 +08:00)

Compare commits: `modular-st…`...`fix/traini…` — 13 commits
| Author | SHA1 | Date |
|---|---|---|
| | fa9df34660 | |
| | 8285e4e72d | |
| | 212afef9bf | |
| | 3b42e96912 | |
| | 3e051abd75 | |
| | c580ff04d5 | |
| | 4c38c229e1 | |
| | 0bb32cc285 | |
| | 1d228b8eb0 | |
| | c05d71be04 | |
| | 86b44367e9 | |
| | 5f6164cdc5 | |
| | e856faea4f | |
```diff
@@ -42,8 +42,8 @@ import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.training_utils import compute_snr, replace_linear_cls
+from diffusers.utils import check_min_version, is_peft_available, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
```
```diff
@@ -466,6 +466,7 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )
 
     # freeze parameters of models to save more memory
     unet.requires_grad_(False)
     vae.requires_grad_(False)
```
```diff
@@ -480,10 +481,14 @@ def main():
         weight_dtype = torch.bfloat16
 
     # Move unet, vae and text_encoder to device and cast to weight_dtype
-    unet.to(accelerator.device, dtype=weight_dtype)
+    # unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
+    # Replace the `nn.Linear` layers with `LoRACompatibleLinear` layers.
+    if is_peft_available():
+        replace_linear_cls(unet)
+
     # now we will add new LoRA weights to the attention layers
     # It's important to realize here how many attention weights will be added and of which sizes
     # The sizes of the attention layers consist only of two different variables:
```
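The block added above swaps the frozen UNet's `nn.Linear` modules for `LoRACompatibleLinear` before any LoRA weights are attached. Below is a minimal, hedged sketch of why that matters, using the `LoRALinearLayer` the script already imports; `set_lora_layer` is recalled from the diffusers API of this era and should be treated as an assumption, and the shapes are toy values.

```python
import torch
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

# Frozen base projection, as the swapped UNet layers would be.
base = LoRACompatibleLinear(320, 320, bias=False)
base.requires_grad_(False)

# Attach a rank-4 update; only its two small matrices are trainable.
lora = LoRALinearLayer(in_features=320, out_features=320, rank=4)
base.set_lora_layer(lora)  # assumed attachment hook on LoRACompatibleLinear

x = torch.randn(2, 320)
out = base(x)  # base weight plus the (initially zero) low-rank update

trainable = [p for p in base.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable))  # 2 * 320 * 4 = 2560
```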
```diff
@@ -700,10 +705,14 @@ def main():
     )
 
     # Prepare everything with our `accelerator`.
-    unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+    # unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+    #     unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+    # )
+    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, optimizer, train_dataloader, lr_scheduler
     )
 
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
```
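The `prepare()` call now receives the whole `unet` rather than the bare list of LoRA parameters, so the model itself gets device placement and any distributed wrapping. A minimal self-contained sketch of that pattern with a toy model and dataset (names and shapes are illustrative, not the script's):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Toy stand-in: freeze one layer to mimic a frozen base with few trainable params.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
for p in model[0].parameters():
    p.requires_grad_(False)

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)
dataloader = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=8)
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

accelerator = Accelerator()
# Hand the module itself to prepare(): Accelerate wraps nn.Module, Optimizer,
# DataLoader and LR-scheduler objects, while the optimizer still only updates
# the trainable parameters it was built with.
model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, dataloader, lr_scheduler
)
```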
```diff
@@ -33,8 +33,8 @@ if is_torch_available():
     _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["controlnet"] = ["ControlNetModel"]
     _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
-    _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["embeddings"] = ["ImageProjection"]
+    _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["prior_transformer"] = ["PriorTransformer"]
     _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformer_2d"] = ["Transformer2DModel"]
```
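For context (not part of the diff): the `_import_structure` entries feed diffusers' lazy module loading, so moving the `"modeling_utils"` key back into alphabetical order is cosmetic. A small sketch, assuming the standard diffusers layout:

```python
from diffusers.models import ModelMixin
from diffusers.models.modeling_utils import ModelMixin as DirectModelMixin

# The lazily exposed name and the direct import resolve to the same class,
# regardless of where the "modeling_utils" entry sits in _import_structure.
assert ModelMixin is DirectModelMixin
```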
```diff
@@ -88,6 +88,7 @@ class GEGLU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
         linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+        self.linear_cls = linear_cls
 
         self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
 
```
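The same two-line pattern recurs in the hunks that follow: choose the linear class once per module from the active backend and keep a handle to it on the instance. An illustrative sketch with a made-up `ToyBlock`; the diffusers import paths are the ones used in this diff and in the library of this era, treated here as assumptions:

```python
import torch.nn as nn
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils import USE_PEFT_BACKEND


class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
        # Keeping the chosen class on the instance records which backend built
        # this block's layers, so callers can inspect it later.
        self.linear_cls = linear_cls
        self.proj = linear_cls(dim, dim * 2)


block = ToyBlock(64)
print(block.linear_cls)  # nn.Linear under the PEFT backend, LoRACompatibleLinear otherwise
```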
```diff
@@ -175,11 +175,8 @@ class Attention(nn.Module):
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )
 
-        if USE_PEFT_BACKEND:
-            linear_cls = nn.Linear
-        else:
-            linear_cls = LoRACompatibleLinear
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls
 
         self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
 
         if not self.only_cross_attention:
```
```diff
@@ -200,6 +200,7 @@ class TimestepEmbedding(nn.Module):
     ):
         super().__init__()
         linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls
 
         self.linear_1 = linear_cls(in_channels, time_embed_dim)
 
```
```diff
@@ -649,6 +649,7 @@ class ResnetBlock2D(nn.Module):
         self.skip_time_act = skip_time_act
 
         linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls
         conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
 
         if groups_out is None:
```
```diff
@@ -107,6 +107,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
         linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls
 
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
```
```diff
@@ -35,6 +35,7 @@ class TimestepBlock(nn.Module):
     def __init__(self, c, c_timestep):
         super().__init__()
         linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls
         self.mapper = linear_cls(c_timestep, c * 2)
 
     def forward(self, x, t):
```
```diff
@@ -43,6 +43,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         super().__init__()
         conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
         linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        self.linear_cls = linear_cls
 
         self.c_r = c_r
         self.projection = conv_cls(c_in, c, kernel_size=1)
```
```diff
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 
 from .models import UNet2DConditionModel
+from .models.lora import LoRACompatibleLinear
 from .utils import deprecate, is_transformers_available
 
 
```
```diff
@@ -53,6 +54,24 @@ def compute_snr(noise_scheduler, timesteps):
     return snr
 
 
+@torch.no_grad()
+def replace_linear_cls(model):
+    for name, module in model.named_children():
+        if isinstance(module, torch.nn.Linear):
+            bias = True if hasattr(module, "bias") and getattr(module, "bias", None) is not None else False
+            new_linear_cls = LoRACompatibleLinear(module.in_features, module.out_features, bias=bias)
+            new_linear_cls.weight.copy_(module.weight.data)
+            new_linear_cls.weight.data.to(device=module.weight.data.device, dtype=module.weight.data.dtype)
+            if bias:
+                new_linear_cls.bias.copy_(module.bias.data)
+                new_linear_cls.bias.data.to(device=module.bias.data.device, dtype=module.bias.data.dtype)
+            setattr(model, name, new_linear_cls)
+
+        elif len(list(module.children())) > 0:
+            # Recursively apply the same operation to child modules
+            replace_linear_cls(module)
+
+
 def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     r"""
     Returns:
```
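A hedged usage sketch for the new helper on a toy module rather than the training script's UNet; it assumes a diffusers build that includes this change:

```python
import torch.nn as nn
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.training_utils import replace_linear_cls

toy = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Sequential(nn.Linear(32, 4, bias=False)),  # nested module to exercise the recursion
)
replace_linear_cls(toy)

# Every nn.Linear, nested ones included, is now a LoRACompatibleLinear whose
# weights were copied from the original layer under torch.no_grad().
assert isinstance(toy[0], LoRACompatibleLinear)
assert isinstance(toy[2][0], LoRACompatibleLinear)
assert toy[2][0].bias is None
```

One observation: `Tensor.to(...)` returns a new tensor rather than converting in place, so the device and dtype of the swapped layers effectively come from their construction and from wherever the surrounding model is moved afterwards (in the training script, by `accelerator.prepare`).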