Compare commits

...

13 Commits

Author SHA1 Message Date
sayakpaul
fa9df34660 check 2023-12-08 17:41:03 +05:30
sayakpaul
8285e4e72d potentially fix device placement 2023-12-08 17:30:38 +05:30
sayakpaul
212afef9bf debug 2023-12-08 17:27:26 +05:30
sayakpaul
3b42e96912 better/ 2023-12-08 17:23:43 +05:30
sayakpaul
3e051abd75 device and dtype 2023-12-08 17:19:05 +05:30
sayakpaul
c580ff04d5 better handle bias 2023-12-08 17:17:06 +05:30
sayakpaul
4c38c229e1 fix: bias copy 2023-12-08 17:15:49 +05:30
sayakpaul
0bb32cc285 import 2023-12-08 17:12:16 +05:30
sayakpaul
1d228b8eb0 torch.no_grad deco 2023-12-08 17:11:48 +05:30
Sayak Paul
c05d71be04 Merge branch 'main' into fix/training-peft-installed 2023-12-05 15:19:10 +05:30
sayakpaul
86b44367e9 setting the params too 2023-12-04 18:24:35 +05:30
sayakpaul
5f6164cdc5 replace linear_cls in case peft is installed. 2023-12-04 17:53:49 +05:30
sayakpaul
e856faea4f feat: make linear_cls a class member when needed. 2023-12-04 17:30:03 +05:30
10 changed files with 42 additions and 11 deletions

View File

@@ -42,8 +42,8 @@ import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.training_utils import compute_snr, replace_linear_cls
from diffusers.utils import check_min_version, is_peft_available, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -466,6 +466,7 @@ def main():
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
@@ -480,10 +481,14 @@ def main():
weight_dtype = torch.bfloat16
# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
# unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# Replace the `nn.Linear` layers with `LoRACompatibleLinear` layers.
if is_peft_available():
replace_linear_cls(unet)
# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
@@ -700,10 +705,14 @@ def main():
)
# Prepare everything with our `accelerator`.
unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
# unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
# unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
# )
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:

View File

@@ -33,8 +33,8 @@ if is_torch_available():
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["prior_transformer"] = ["PriorTransformer"]
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]

View File

@@ -88,6 +88,7 @@ class GEGLU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
self.linear_cls = linear_cls
self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)

View File

@@ -175,11 +175,8 @@ class Attention(nn.Module):
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)
if USE_PEFT_BACKEND:
linear_cls = nn.Linear
else:
linear_cls = LoRACompatibleLinear
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:

View File

@@ -200,6 +200,7 @@ class TimestepEmbedding(nn.Module):
):
super().__init__()
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.linear_cls = linear_cls
self.linear_1 = linear_cls(in_channels, time_embed_dim)

View File

@@ -649,6 +649,7 @@ class ResnetBlock2D(nn.Module):
self.skip_time_act = skip_time_act
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.linear_cls = linear_cls
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if groups_out is None:

View File

@@ -107,6 +107,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.linear_cls = linear_cls
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration

View File

@@ -35,6 +35,7 @@ class TimestepBlock(nn.Module):
def __init__(self, c, c_timestep):
super().__init__()
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.linear_cls = linear_cls
self.mapper = linear_cls(c_timestep, c * 2)
def forward(self, x, t):

View File

@@ -43,6 +43,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
super().__init__()
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.linear_cls = linear_cls
self.c_r = c_r
self.projection = conv_cls(c_in, c, kernel_size=1)

View File

@@ -7,6 +7,7 @@ import numpy as np
import torch
from .models import UNet2DConditionModel
from .models.lora import LoRACompatibleLinear
from .utils import deprecate, is_transformers_available
@@ -53,6 +54,24 @@ def compute_snr(noise_scheduler, timesteps):
return snr
@torch.no_grad()
def replace_linear_cls(model):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
bias = True if hasattr(module, "bias") and getattr(module, "bias", None) is not None else False
new_linear_cls = LoRACompatibleLinear(module.in_features, module.out_features, bias=bias)
new_linear_cls.weight.copy_(module.weight.data)
new_linear_cls.weight.data.to(device=module.weight.data.device, dtype=module.weight.data.dtype)
if bias:
new_linear_cls.bias.copy_(module.bias.data)
new_linear_cls.bias.data.to(device=module.bias.data.device, dtype=module.bias.data.dtype)
setattr(model, name, new_linear_cls)
elif len(list(module.children())) > 0:
# Recursively apply the same operation to child modules
replace_linear_cls(module)
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
r"""
Returns: