Compare commits

...

32 Commits

Author  SHA1  Message  Date
Sayak Paul  f9c7b327cb  Merge branch 'main' into fix/lora-loading  2023-12-16 08:44:45 +05:30
Sayak Paul  255adf0466  Merge branch 'main' into fix/lora-loading  2023-12-15 22:43:50 +05:30
sayakpaul  5d04eebd1f  fix attribute access.  2023-12-15 18:32:57 +05:30
sayakpaul  f145d48ed7  add test  2023-12-15 18:23:07 +05:30
sayakpaul  765fef7134  add: doc strings.  2023-12-15 17:58:26 +05:30
sayakpaul  8c98a187c7  propagate to sdxl t2i lora fine-tuning  2023-12-15 17:53:54 +05:30
sayakpaul  ece6d89cf2  propagate to sd t2i lora fine-tuning  2023-12-15 17:49:42 +05:30
sayakpaul  16ac1b2f4f  propagate changes to sd dreambooth lora.  2023-12-15 17:46:48 +05:30
sayakpaul  ec9df6fc48  simplify condition.  2023-12-15 17:14:39 +05:30
sayakpaul  09618d09a6  remove print  2023-12-15 14:24:28 +05:30
sayakpaul  d24e7d3ea9  debug  2023-12-15 14:14:33 +05:30
sayakpaul  57a16f35ee  fix: import error  2023-12-15 12:59:32 +05:30
sayakpaul  f4adaae5cb  move config related stuff in a separate utility.  2023-12-15 12:49:01 +05:30
Sayak Paul  49a0f3ab02  Merge branch 'main' into fix/lora-loading  2023-12-15 12:33:49 +05:30
sayakpaul  24cb282d36  fix  2023-12-11 21:38:02 +05:30
sayakpaul  bcf0f4a789  fix  2023-12-11 18:44:05 +05:30
sayakpaul  fdb114618d  Empty-Commit (Co-authored-by: pacman100 <13534540+pacman100@users.noreply.github.com>)  2023-12-11 18:33:00 +05:30
sayakpaul  ed333f06ae  remove print  2023-12-11 18:31:52 +05:30
sayakpaul  32212b6df6  json unwrap  2023-12-11 18:31:13 +05:30
sayakpaul  9ecb271ac8  unwrap  2023-12-11 18:24:15 +05:30
sayakpaul  a2792cd942  unwrap  2023-12-11 18:24:09 +05:30
sayakpaul  c341111d69  ifx  2023-12-11 18:22:20 +05:30
sayakpaul  b868e8a2fc  ifx  2023-12-11 18:20:20 +05:30
sayakpaul  41b9cd8787  fix?  2023-12-11 18:17:22 +05:30
sayakpaul  0d08249c9f  ifx?  2023-12-11 18:16:35 +05:30
sayakpaul  3b27b23082  dehug  2023-12-11 18:05:35 +05:30
sayakpaul  e4c00bc5c2  debug  2023-12-11 18:03:21 +05:30
sayakpaul  cf132fb6b0  debug  2023-12-11 17:57:51 +05:30
sayakpaul  981ea82591  assertion  2023-12-11 17:56:31 +05:30
sayakpaul  20fac7bc9d  better conditioning  2023-12-11 17:54:07 +05:30
sayakpaul  79b16373b7  fix  2023-12-11 17:50:00 +05:30
sayakpaul  ff3d380824  fix: parse lora_alpha correctly  2023-12-11 17:15:05 +05:30
7 changed files with 282 additions and 41 deletions
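
The diff below threads the PEFT LoraConfig (most importantly lora_alpha) through save_lora_weights and back out of load_lora_weights by storing it as safetensors metadata, so that an adapter trained with lora_alpha different from r keeps its scaling after a save/load round trip. A minimal round-trip sketch, assuming a diffusers build that already contains these changes; the unet_lora_config keyword and the metadata handling come from this diff rather than a released API, and the directory name and rank/alpha values are only illustrative:

from diffusers import StableDiffusionPipeline
from peft import LoraConfig, get_peft_model_state_dict

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.unet.add_adapter(LoraConfig(r=4, lora_alpha=8, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))
# ... train the adapter ...

StableDiffusionPipeline.save_lora_weights(
    save_directory="my-lora",
    unet_lora_layers=get_peft_model_state_dict(pipe.unet),
    unet_lora_config=pipe.unet.peft_config["default"],  # carries lora_alpha into the safetensors metadata
)

fresh_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
fresh_pipe.load_lora_weights("my-lora")
assert fresh_pipe.unet.peft_config["default"].lora_alpha == 8  # previously this silently fell back to r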

View File

@@ -880,11 +880,16 @@ def main(args):
unet_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None
unet_lora_config = None
text_encoder_lora_config = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -895,6 +900,8 @@ def main(args):
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
)
def load_model_hook(models, input_dir):
@@ -911,10 +918,12 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
- lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
- LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+ lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
+ LoraLoaderMixin.load_lora_into_unet(
+ lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
+ )
LoraLoaderMixin.load_lora_into_text_encoder(
- lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_, config=metadata
)
accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1315,17 +1324,22 @@ def main(args):
unet = unet.to(torch.float32)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
if args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
text_encoder_lora_config = text_encoder.peft_config["default"]
else:
text_encoder_state_dict = None
text_encoder_lora_config = None
LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
)
# Final inference
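
For background on where the new unet_lora_config value comes from: when the PEFT backend injects a LoRA adapter, it records one LoraConfig per adapter name on the model, and the hooks above read the entry for the default adapter. A sketch of that relationship (names assumed to match the peft backend used by these scripts):

from peft import LoraConfig

unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=8,  # the value this PR makes survive a save/load round trip
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    init_lora_weights=False,
)
# After unet.add_adapter(unet_lora_config), the same object is available as
# unet.peft_config["default"], which is what the save hook forwards to save_lora_weights.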

View File

@@ -1033,13 +1033,20 @@ def main(args):
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
unet_lora_config = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1051,6 +1058,9 @@ def main(args):
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
def load_model_hook(models, input_dir):
@@ -1070,17 +1080,19 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
- lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
- LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+ lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
+ LoraLoaderMixin.load_lora_into_unet(
+ lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
+ )
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
- text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+ text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
)
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
- text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+ text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
)
accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1616,21 +1628,29 @@ def main(args):
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
StableDiffusionXLPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
# Final inference
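
For the SDXL scripts the same idea is applied per component: the loader code packs one JSON string of the corresponding LoraConfig under each state-dict prefix, and load_lora_into_text_encoder later picks its entry via config[prefix]. An illustrative (assumed) shape of that metadata dict, with dummy values:

import json

metadata = {
    prefix: json.dumps({"r": 4, "lora_alpha": 8})
    for prefix in ("unet", "text_encoder", "text_encoder_2")
}
# What the loader effectively does for the second text encoder:
print(json.loads(metadata["text_encoder_2"])["lora_alpha"])  # -> 8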

View File

@@ -833,10 +833,12 @@ def main():
accelerator.save_state(save_path)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
unet_lora_config=unet_lora_config,
safe_serialization=True,
)
@@ -898,10 +900,12 @@ def main():
unet = unet.to(torch.float32)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
unet_lora_config=unet_lora_config,
)
if args.push_to_hub:

View File

@@ -682,13 +682,20 @@ def main(args):
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
unet_lora_config = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -700,6 +707,9 @@ def main(args):
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
def load_model_hook(models, input_dir):
@@ -719,17 +729,19 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
- lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
- LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+ lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
+ LoraLoaderMixin.load_lora_into_unet(
+ lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
+ )
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
- text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+ text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
)
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
- text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+ text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
)
accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1194,6 +1206,7 @@ def main(args):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
@@ -1201,15 +1214,23 @@ def main(args):
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
StableDiffusionXLPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
del unet

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union
@@ -103,7 +104,7 @@ class LoraLoaderMixin:
`default_{i}` where i is the total number of adapters being loaded.
"""
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -114,6 +115,7 @@ class LoraLoaderMixin:
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
config=metadata,
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
@@ -125,6 +127,7 @@ class LoraLoaderMixin:
text_encoder=getattr(self, self.text_encoder_name)
if not hasattr(self, "text_encoder")
else self.text_encoder,
config=metadata,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
@@ -219,6 +222,7 @@ class LoraLoaderMixin:
}
model_file = None
metadata = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
@@ -248,6 +252,8 @@ class LoraLoaderMixin:
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
metadata = f.metadata()
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
@@ -294,7 +300,7 @@ class LoraLoaderMixin:
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
- return state_dict, network_alphas
+ return state_dict, network_alphas, metadata
@classmethod
def _best_guess_weight_name(
@@ -370,7 +376,7 @@ class LoraLoaderMixin:
@classmethod
def load_lora_into_unet(
- cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
+ cls, state_dict, network_alphas, unet, config=None, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -384,6 +390,8 @@ class LoraLoaderMixin:
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
config: (`Dict`):
LoRA configuration parsed from the state dict.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
@@ -443,7 +451,9 @@ class LoraLoaderMixin:
if "lora_B" in key:
rank[key] = val.shape[1]
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+ if config is not None and isinstance(config, dict) and len(config) > 0:
+ config = json.loads(config["unet"])
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True)
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
@@ -484,6 +494,7 @@ class LoraLoaderMixin:
network_alphas,
text_encoder,
prefix=None,
config=None,
lora_scale=1.0,
low_cpu_mem_usage=None,
adapter_name=None,
@@ -502,6 +513,8 @@ class LoraLoaderMixin:
The text encoder model to load the LoRA layers into.
prefix (`str`):
Expected prefix of the `text_encoder` in the `state_dict`.
config (`Dict`):
LoRA configuration parsed from state dict.
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added with the output of the regular
lora layer.
@@ -575,10 +588,11 @@ class LoraLoaderMixin:
if USE_PEFT_BACKEND:
from peft import LoraConfig
if config is not None and len(config) > 0:
config = json.loads(config[prefix])
lora_config_kwargs = get_peft_kwargs(
- rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
+ rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
)
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
@@ -786,6 +800,8 @@ class LoraLoaderMixin:
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
unet_lora_config=None,
text_encoder_lora_config=None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -813,21 +829,54 @@ class LoraLoaderMixin:
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
if not USE_PEFT_BACKEND and not safe_serialization:
if unet_lora_config or text_encoder_lora_config:
raise ValueError(
"Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`."
)
elif USE_PEFT_BACKEND and safe_serialization:
from peft import LoraConfig
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
state_dict = {}
metadata = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
- if not (unet_lora_layers or text_encoder_lora_layers):
- raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
def pack_metadata(config, prefix):
local_metadata = {}
if config is not None:
if isinstance(config, LoraConfig):
config = config.to_dict()
for key, value in config.items():
if isinstance(value, set):
config[key] = list(value)
config_as_string = json.dumps(config, indent=2, sort_keys=True)
local_metadata[prefix] = config_as_string
return local_metadata
if unet_lora_layers:
- state_dict.update(pack_weights(unet_lora_layers, "unet"))
+ prefix = "unet"
+ unet_state_dict = pack_weights(unet_lora_layers, prefix)
+ state_dict.update(unet_state_dict)
+ if unet_lora_config is not None:
+ unet_metadata = pack_metadata(unet_lora_config, prefix)
+ metadata.update(unet_metadata)
if text_encoder_lora_layers:
- state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+ prefix = "text_encoder"
+ text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
+ state_dict.update(text_encoder_state_dict)
+ if text_encoder_lora_config is not None:
+ text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
+ metadata.update(text_encoder_metadata)
# Save the model
cls.write_lora_layers(
@@ -837,6 +886,7 @@ class LoraLoaderMixin:
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
metadata=metadata,
)
@staticmethod
@@ -847,7 +897,11 @@ class LoraLoaderMixin:
weight_name: str,
save_function: Callable,
safe_serialization: bool,
metadata=None,
):
if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0:
raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.")
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@@ -855,8 +909,10 @@ class LoraLoaderMixin:
if save_function is None:
if safe_serialization:
- def save_function(weights, filename):
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+ def save_function(weights, filename, metadata):
+ if metadata is None:
+ metadata = {"format": "pt"}
+ return safetensors.torch.save_file(weights, filename, metadata=metadata)
else:
save_function = torch.save
@@ -869,7 +925,10 @@ class LoraLoaderMixin:
else:
weight_name = LORA_WEIGHT_NAME
- save_function(state_dict, os.path.join(save_directory, weight_name))
+ if save_function != torch.save:
+ save_function(state_dict, os.path.join(save_directory, weight_name), metadata)
+ else:
+ save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
def unload_lora_weights(self):
@@ -1301,7 +1360,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
# pipeline.
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict, network_alphas = self.lora_state_dict(
+ state_dict, network_alphas, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
@@ -1311,7 +1370,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(
- state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+ state_dict,
+ network_alphas=network_alphas,
+ unet=self.unet,
+ config=metadata,
+ adapter_name=adapter_name,
+ _pipeline=self,
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
@@ -1319,6 +1383,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
config=metadata,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
@@ -1331,6 +1396,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder_2_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder_2,
config=metadata,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
@@ -1344,6 +1410,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
unet_lora_config=None,
text_encoder_lora_config=None,
text_encoder_2_lora_config=None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -1371,24 +1440,63 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
- state_dict = {}
- def pack_weights(layers, prefix):
- layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
- layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
- return layers_state_dict
if not USE_PEFT_BACKEND and not safe_serialization:
if unet_lora_config or text_encoder_lora_config or text_encoder_2_lora_config:
raise ValueError(
"Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`."
)
elif USE_PEFT_BACKEND and safe_serialization:
from peft import LoraConfig
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
)
state_dict = {}
metadata = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
def pack_metadata(config, prefix):
local_metadata = {}
if config is not None:
if isinstance(config, LoraConfig):
config = config.to_dict()
for key, value in config.items():
if isinstance(value, set):
config[key] = list(value)
config_as_string = json.dumps(config, indent=2, sort_keys=True)
local_metadata[prefix] = config_as_string
return local_metadata
if unet_lora_layers:
- state_dict.update(pack_weights(unet_lora_layers, "unet"))
+ prefix = "unet"
+ unet_state_dict = pack_weights(unet_lora_layers, prefix)
+ state_dict.update(unet_state_dict)
+ if unet_lora_config is not None:
+ unet_metadata = pack_metadata(unet_lora_config, prefix)
+ metadata.update(unet_metadata)
if text_encoder_lora_layers and text_encoder_2_lora_layers:
- state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
- state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+ prefix = "text_encoder"
+ text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
+ state_dict.update(text_encoder_state_dict)
+ if text_encoder_lora_config is not None:
+ text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
+ metadata.update(text_encoder_metadata)
+ prefix = "text_encoder_2"
+ text_encoder_2_state_dict = pack_weights(text_encoder_2_lora_layers, prefix)
+ state_dict.update(text_encoder_2_state_dict)
+ if text_encoder_2_lora_config is not None:
+ text_encoder_2_metadata = pack_metadata(text_encoder_2_lora_config, prefix)
+ metadata.update(text_encoder_2_metadata)
cls.write_lora_layers(
state_dict=state_dict,
@@ -1397,6 +1505,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
metadata=metadata,
)
def _remove_text_encoder_monkey_patch(self):
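
The storage mechanism used above relies on the fact that a safetensors file carries a flat string-to-string metadata header; that is why each LoraConfig is serialized with json.dumps under its component prefix and parsed back with json.loads at load time. A standalone sketch of that mechanism (the tensor key, shapes, and file name are only illustrative):

import json
import torch
import safetensors.torch
from safetensors import safe_open

tensors = {"unet.to_q.lora_A.weight": torch.zeros(4, 8)}  # placeholder LoRA weight
metadata = {"unet": json.dumps({"r": 4, "lora_alpha": 8})}
safetensors.torch.save_file(tensors, "pytorch_lora_weights.safetensors", metadata=metadata)

with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    unet_config = json.loads(f.metadata()["unet"])
assert unet_config["lora_alpha"] == 8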

View File

@@ -138,11 +138,17 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)
- def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
# Try to retrive config.
alpha_retrieved = False
if config is not None:
lora_alpha = config["lora_alpha"]
alpha_retrieved = True
if len(set(rank_dict.values())) > 1:
# get the rank occuring the most number of times
r = collections.Counter(rank_dict.values()).most_common()[0][0]
@@ -154,7 +160,8 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
if len(set(network_alpha_dict.values())) > 1:
# get the alpha occuring the most number of times
- lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
+ if not alpha_retrieved:
+ lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
# for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
@@ -165,7 +172,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
}
else:
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
- else:
+ elif not alpha_retrieved:
lora_alpha = set(network_alpha_dict.values()).pop()
# layer names without the Diffusers specific
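
The net effect of this change is a precedence rule for lora_alpha: an explicit config recovered from the safetensors metadata wins, otherwise the alphas embedded in the state dict are used, and only then does the old fallback of alpha equal to the rank apply. A simplified illustration of that rule (not the actual helper in peft_utils):

import collections

def resolve_lora_alpha(rank_dict, network_alpha_dict=None, config=None):
    r = list(rank_dict.values())[0]
    if config is not None:
        return config["lora_alpha"]  # explicit config from the metadata wins
    if network_alpha_dict:
        # fall back to the most common alpha stored alongside the weights
        return collections.Counter(network_alpha_dict.values()).most_common()[0][0]
    return r  # legacy behaviour: alpha defaults to the rank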

View File

@@ -107,8 +107,9 @@ class PeftLoraLoaderMixinTests:
unet_kwargs = None
vae_kwargs = None
- def get_dummy_components(self, scheduler_cls=None):
+ def get_dummy_components(self, scheduler_cls=None, lora_alpha=None):
scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
lora_alpha = 4 if lora_alpha is None else lora_alpha
torch.manual_seed(0)
unet = UNet2DConditionModel(**self.unet_kwargs)
@@ -123,11 +124,14 @@ class PeftLoraLoaderMixinTests:
tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
text_lora_config = LoraConfig(
r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
r=4,
lora_alpha=lora_alpha,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
init_lora_weights=False,
)
unet_lora_config = LoraConfig(
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
r=4, lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
@@ -714,6 +718,68 @@ class PeftLoraLoaderMixinTests:
"Fused lora should change the output",
)
def test_if_lora_alpha_is_correctly_parsed(self):
lora_alpha = 8
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.unet.add_adapter(unet_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
# Inference works?
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname:
unet_state_dict = get_peft_model_state_dict(pipe.unet)
unet_lora_config = pipe.unet.peft_config["default"]
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
text_encoder_lora_config = pipe.text_encoder.peft_config["default"]
if self.has_two_text_encoders:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"]
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=unet_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
text_encoder_2_lora_config=text_encoder_2_lora_config,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=unet_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
)
loaded_pipe = self.pipeline_class(**components)
loaded_pipe.load_lora_weights(tmpdirname)
# Inference works?
_ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images
assert (
loaded_pipe.unet.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for UNet."
assert (
loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder."
if self.has_two_text_encoders:
assert (
loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder 2."
def test_simple_inference_with_text_unet_lora_unfused(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
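
Why the round trip checked by the new test matters: the LoRA update is applied as delta_W = (lora_alpha / r) * B @ A, so if lora_alpha is dropped on reload and silently defaults to r, the adapter's effective strength changes even though the weights themselves are identical. A small numeric sketch of that scaling (standard LoRA math, not code from this PR):

import torch

r, lora_alpha = 4, 8
A = torch.randn(r, 16)                 # lora_A ("down") projection
B = torch.randn(16, r)                 # lora_B ("up") projection
delta_w = (lora_alpha / r) * (B @ A)   # scaling factor 2.0 here
# With the buggy fallback alpha == r, the factor would silently become 1.0,
# halving the adapter's contribution at inference time.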