mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
Compare commits
32 Commits
controlnet
...
fix/lora-l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f9c7b327cb | ||
|
|
255adf0466 | ||
|
|
5d04eebd1f | ||
|
|
f145d48ed7 | ||
|
|
765fef7134 | ||
|
|
8c98a187c7 | ||
|
|
ece6d89cf2 | ||
|
|
16ac1b2f4f | ||
|
|
ec9df6fc48 | ||
|
|
09618d09a6 | ||
|
|
d24e7d3ea9 | ||
|
|
57a16f35ee | ||
|
|
f4adaae5cb | ||
|
|
49a0f3ab02 | ||
|
|
24cb282d36 | ||
|
|
bcf0f4a789 | ||
|
|
fdb114618d | ||
|
|
ed333f06ae | ||
|
|
32212b6df6 | ||
|
|
9ecb271ac8 | ||
|
|
a2792cd942 | ||
|
|
c341111d69 | ||
|
|
b868e8a2fc | ||
|
|
41b9cd8787 | ||
|
|
0d08249c9f | ||
|
|
3b27b23082 | ||
|
|
e4c00bc5c2 | ||
|
|
cf132fb6b0 | ||
|
|
981ea82591 | ||
|
|
20fac7bc9d | ||
|
|
79b16373b7 | ||
|
|
ff3d380824 |
@@ -880,11 +880,16 @@ def main(args):
|
||||
unet_lora_layers_to_save = None
|
||||
text_encoder_lora_layers_to_save = None
|
||||
|
||||
unet_lora_config = None
|
||||
text_encoder_lora_config = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
unet_lora_config = model.peft_config["default"]
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_lora_config = model.peft_config["default"]
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -895,6 +900,8 @@ def main(args):
|
||||
output_dir,
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_lora_config,
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
@@ -911,10 +918,12 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(
|
||||
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
|
||||
)
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_, config=metadata
|
||||
)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
@@ -1315,17 +1324,22 @@ def main(args):
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unet_lora_config = unet.peft_config["default"]
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
|
||||
text_encoder_lora_config = text_encoder.peft_config["default"]
|
||||
else:
|
||||
text_encoder_state_dict = None
|
||||
text_encoder_lora_config = None
|
||||
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
text_encoder_lora_layers=text_encoder_state_dict,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_lora_config,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
|
||||
@@ -1033,13 +1033,20 @@ def main(args):
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
unet_lora_config = None
|
||||
text_encoder_one_lora_config = None
|
||||
text_encoder_two_lora_config = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
unet_lora_config = model.peft_config["default"]
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_one_lora_config = model.peft_config["default"]
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_two_lora_config = model.peft_config["default"]
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1051,6 +1058,9 @@ def main(args):
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_one_lora_config,
|
||||
text_encoder_2_lora_config=text_encoder_two_lora_config,
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
@@ -1070,17 +1080,19 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(
|
||||
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
|
||||
)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
|
||||
)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
@@ -1616,21 +1628,29 @@ def main(args):
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = get_peft_model_state_dict(unet)
|
||||
unet_lora_config = unet.peft_config["default"]
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||
text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
|
||||
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
text_encoder_one_lora_config = None
|
||||
text_encoder_two_lora_config = None
|
||||
|
||||
StableDiffusionXLPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_one_lora_config,
|
||||
text_encoder_2_lora_config=text_encoder_two_lora_config,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
|
||||
@@ -833,10 +833,12 @@ def main():
|
||||
accelerator.save_state(save_path)
|
||||
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unet_lora_config = unet.peft_config["default"]
|
||||
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=save_path,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
unet_lora_config=unet_lora_config,
|
||||
safe_serialization=True,
|
||||
)
|
||||
|
||||
@@ -898,10 +900,12 @@ def main():
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unet_lora_config = unet.peft_config["default"]
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
safe_serialization=True,
|
||||
unet_lora_config=unet_lora_config,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -682,13 +682,20 @@ def main(args):
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
unet_lora_config = None
|
||||
text_encoder_one_lora_config = None
|
||||
text_encoder_two_lora_config = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
unet_lora_config = model.peft_config["default"]
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_one_lora_config = model.peft_config["default"]
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_two_lora_config = model.peft_config["default"]
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -700,6 +707,9 @@ def main(args):
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_one_lora_config,
|
||||
text_encoder_2_lora_config=text_encoder_two_lora_config,
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
@@ -719,17 +729,19 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(
|
||||
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
|
||||
)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
|
||||
)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
@@ -1194,6 +1206,7 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unet_lora_config = unet.peft_config["default"]
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
@@ -1201,15 +1214,23 @@ def main(args):
|
||||
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
|
||||
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
|
||||
|
||||
text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
|
||||
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
text_encoder_one_lora_config = None
|
||||
text_encoder_two_lora_config = None
|
||||
|
||||
StableDiffusionXLPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_one_lora_config,
|
||||
text_encoder_2_lora_config=text_encoder_two_lora_config,
|
||||
)
|
||||
|
||||
del unet
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
@@ -103,7 +104,7 @@ class LoraLoaderMixin:
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
@@ -114,6 +115,7 @@ class LoraLoaderMixin:
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
config=metadata,
|
||||
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
@@ -125,6 +127,7 @@ class LoraLoaderMixin:
|
||||
text_encoder=getattr(self, self.text_encoder_name)
|
||||
if not hasattr(self, "text_encoder")
|
||||
else self.text_encoder,
|
||||
config=metadata,
|
||||
lora_scale=self.lora_scale,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
@@ -219,6 +222,7 @@ class LoraLoaderMixin:
|
||||
}
|
||||
|
||||
model_file = None
|
||||
metadata = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
@@ -248,6 +252,8 @@ class LoraLoaderMixin:
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
except (IOError, safetensors.SafetensorError) as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
@@ -294,7 +300,7 @@ class LoraLoaderMixin:
|
||||
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
|
||||
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
|
||||
|
||||
return state_dict, network_alphas
|
||||
return state_dict, network_alphas, metadata
|
||||
|
||||
@classmethod
|
||||
def _best_guess_weight_name(
|
||||
@@ -370,7 +376,7 @@ class LoraLoaderMixin:
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_unet(
|
||||
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
|
||||
cls, state_dict, network_alphas, unet, config=None, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `unet`.
|
||||
@@ -384,6 +390,8 @@ class LoraLoaderMixin:
|
||||
See `LoRALinearLayer` for more details.
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model to load the LoRA layers into.
|
||||
config: (`Dict`):
|
||||
LoRA configuration parsed from the state dict.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
@@ -443,7 +451,9 @@ class LoraLoaderMixin:
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
|
||||
if config is not None and isinstance(config, dict) and len(config) > 0:
|
||||
config = json.loads(config["unet"])
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
@@ -484,6 +494,7 @@ class LoraLoaderMixin:
|
||||
network_alphas,
|
||||
text_encoder,
|
||||
prefix=None,
|
||||
config=None,
|
||||
lora_scale=1.0,
|
||||
low_cpu_mem_usage=None,
|
||||
adapter_name=None,
|
||||
@@ -502,6 +513,8 @@ class LoraLoaderMixin:
|
||||
The text encoder model to load the LoRA layers into.
|
||||
prefix (`str`):
|
||||
Expected prefix of the `text_encoder` in the `state_dict`.
|
||||
config (`Dict`):
|
||||
LoRA configuration parsed from state dict.
|
||||
lora_scale (`float`):
|
||||
How much to scale the output of the lora linear layer before it is added with the output of the regular
|
||||
lora layer.
|
||||
@@ -575,10 +588,11 @@ class LoraLoaderMixin:
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft import LoraConfig
|
||||
|
||||
if config is not None and len(config) > 0:
|
||||
config = json.loads(config[prefix])
|
||||
lora_config_kwargs = get_peft_kwargs(
|
||||
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
|
||||
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
|
||||
)
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
@@ -786,6 +800,8 @@ class LoraLoaderMixin:
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
unet_lora_config=None,
|
||||
text_encoder_lora_config=None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -813,21 +829,54 @@ class LoraLoaderMixin:
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND and not safe_serialization:
|
||||
if unet_lora_config or text_encoder_lora_config:
|
||||
raise ValueError(
|
||||
"Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`."
|
||||
)
|
||||
elif USE_PEFT_BACKEND and safe_serialization:
|
||||
from peft import LoraConfig
|
||||
|
||||
if not (unet_lora_layers or text_encoder_lora_layers):
|
||||
raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
state_dict = {}
|
||||
metadata = {}
|
||||
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
|
||||
return layers_state_dict
|
||||
|
||||
if not (unet_lora_layers or text_encoder_lora_layers):
|
||||
raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
|
||||
def pack_metadata(config, prefix):
|
||||
local_metadata = {}
|
||||
if config is not None:
|
||||
if isinstance(config, LoraConfig):
|
||||
config = config.to_dict()
|
||||
for key, value in config.items():
|
||||
if isinstance(value, set):
|
||||
config[key] = list(value)
|
||||
|
||||
config_as_string = json.dumps(config, indent=2, sort_keys=True)
|
||||
local_metadata[prefix] = config_as_string
|
||||
return local_metadata
|
||||
|
||||
if unet_lora_layers:
|
||||
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
||||
prefix = "unet"
|
||||
unet_state_dict = pack_weights(unet_lora_layers, prefix)
|
||||
state_dict.update(unet_state_dict)
|
||||
if unet_lora_config is not None:
|
||||
unet_metadata = pack_metadata(unet_lora_config, prefix)
|
||||
metadata.update(unet_metadata)
|
||||
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
prefix = "text_encoder"
|
||||
text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
|
||||
state_dict.update(text_encoder_state_dict)
|
||||
if text_encoder_lora_config is not None:
|
||||
text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
|
||||
metadata.update(text_encoder_metadata)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
@@ -837,6 +886,7 @@ class LoraLoaderMixin:
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -847,7 +897,11 @@ class LoraLoaderMixin:
|
||||
weight_name: str,
|
||||
save_function: Callable,
|
||||
safe_serialization: bool,
|
||||
metadata=None,
|
||||
):
|
||||
if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0:
|
||||
raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.")
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
@@ -855,8 +909,10 @@ class LoraLoaderMixin:
|
||||
if save_function is None:
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
def save_function(weights, filename, metadata):
|
||||
if metadata is None:
|
||||
metadata = {"format": "pt"}
|
||||
return safetensors.torch.save_file(weights, filename, metadata=metadata)
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
@@ -869,7 +925,10 @@ class LoraLoaderMixin:
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
if save_function != torch.save:
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name), metadata)
|
||||
else:
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
||||
|
||||
def unload_lora_weights(self):
|
||||
@@ -1301,7 +1360,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
# pipeline.
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(
|
||||
state_dict, network_alphas, metadata = self.lora_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
unet_config=self.unet.config,
|
||||
**kwargs,
|
||||
@@ -1311,7 +1370,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_unet(
|
||||
state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
unet=self.unet,
|
||||
config=metadata,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
)
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
@@ -1319,6 +1383,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
text_encoder_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
config=metadata,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
@@ -1331,6 +1396,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
text_encoder_2_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder_2,
|
||||
config=metadata,
|
||||
prefix="text_encoder_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
@@ -1344,6 +1410,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
unet_lora_config=None,
|
||||
text_encoder_lora_config=None,
|
||||
text_encoder_2_lora_config=None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -1371,24 +1440,63 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
state_dict = {}
|
||||
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
if not USE_PEFT_BACKEND and not safe_serialization:
|
||||
if unet_lora_config or text_encoder_lora_config or text_encoder_2_lora_config:
|
||||
raise ValueError(
|
||||
"Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`."
|
||||
)
|
||||
elif USE_PEFT_BACKEND and safe_serialization:
|
||||
from peft import LoraConfig
|
||||
|
||||
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
||||
raise ValueError(
|
||||
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
|
||||
)
|
||||
|
||||
state_dict = {}
|
||||
metadata = {}
|
||||
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
|
||||
return layers_state_dict
|
||||
|
||||
def pack_metadata(config, prefix):
|
||||
local_metadata = {}
|
||||
if config is not None:
|
||||
if isinstance(config, LoraConfig):
|
||||
config = config.to_dict()
|
||||
for key, value in config.items():
|
||||
if isinstance(value, set):
|
||||
config[key] = list(value)
|
||||
|
||||
config_as_string = json.dumps(config, indent=2, sort_keys=True)
|
||||
local_metadata[prefix] = config_as_string
|
||||
return local_metadata
|
||||
|
||||
if unet_lora_layers:
|
||||
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
||||
prefix = "unet"
|
||||
unet_state_dict = pack_weights(unet_lora_layers, prefix)
|
||||
state_dict.update(unet_state_dict)
|
||||
if unet_lora_config is not None:
|
||||
unet_metadata = pack_metadata(unet_lora_config, prefix)
|
||||
metadata.update(unet_metadata)
|
||||
|
||||
if text_encoder_lora_layers and text_encoder_2_lora_layers:
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
prefix = "text_encoder"
|
||||
text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
|
||||
state_dict.update(text_encoder_state_dict)
|
||||
if text_encoder_lora_config is not None:
|
||||
text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
|
||||
metadata.update(text_encoder_metadata)
|
||||
|
||||
prefix = "text_encoder_2"
|
||||
text_encoder_2_state_dict = pack_weights(text_encoder_2_lora_layers, prefix)
|
||||
state_dict.update(text_encoder_2_state_dict)
|
||||
if text_encoder_2_lora_config is not None:
|
||||
text_encoder_2_metadata = pack_metadata(text_encoder_2_lora_config, prefix)
|
||||
metadata.update(text_encoder_2_metadata)
|
||||
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
@@ -1397,6 +1505,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
|
||||
@@ -138,11 +138,17 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
|
||||
module.set_scale(adapter_name, 1.0)
|
||||
|
||||
|
||||
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
|
||||
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True):
|
||||
rank_pattern = {}
|
||||
alpha_pattern = {}
|
||||
r = lora_alpha = list(rank_dict.values())[0]
|
||||
|
||||
# Try to retrive config.
|
||||
alpha_retrieved = False
|
||||
if config is not None:
|
||||
lora_alpha = config["lora_alpha"]
|
||||
alpha_retrieved = True
|
||||
|
||||
if len(set(rank_dict.values())) > 1:
|
||||
# get the rank occuring the most number of times
|
||||
r = collections.Counter(rank_dict.values()).most_common()[0][0]
|
||||
@@ -154,7 +160,8 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
|
||||
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
|
||||
if len(set(network_alpha_dict.values())) > 1:
|
||||
# get the alpha occuring the most number of times
|
||||
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
|
||||
if not alpha_retrieved:
|
||||
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
|
||||
|
||||
# for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
|
||||
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
|
||||
@@ -165,7 +172,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
|
||||
}
|
||||
else:
|
||||
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
|
||||
else:
|
||||
elif not alpha_retrieved:
|
||||
lora_alpha = set(network_alpha_dict.values()).pop()
|
||||
|
||||
# layer names without the Diffusers specific
|
||||
|
||||
@@ -107,8 +107,9 @@ class PeftLoraLoaderMixinTests:
|
||||
unet_kwargs = None
|
||||
vae_kwargs = None
|
||||
|
||||
def get_dummy_components(self, scheduler_cls=None):
|
||||
def get_dummy_components(self, scheduler_cls=None, lora_alpha=None):
|
||||
scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
|
||||
lora_alpha = 4 if lora_alpha is None else lora_alpha
|
||||
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(**self.unet_kwargs)
|
||||
@@ -123,11 +124,14 @@ class PeftLoraLoaderMixinTests:
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
|
||||
|
||||
text_lora_config = LoraConfig(
|
||||
r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
|
||||
r=4,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
init_lora_weights=False,
|
||||
)
|
||||
|
||||
unet_lora_config = LoraConfig(
|
||||
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
|
||||
r=4, lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
|
||||
)
|
||||
|
||||
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
|
||||
@@ -714,6 +718,68 @@ class PeftLoraLoaderMixinTests:
|
||||
"Fused lora should change the output",
|
||||
)
|
||||
|
||||
def test_if_lora_alpha_is_correctly_parsed(self):
|
||||
lora_alpha = 8
|
||||
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.unet.add_adapter(unet_lora_config)
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
if self.has_two_text_encoders:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
|
||||
# Inference works?
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
unet_state_dict = get_peft_model_state_dict(pipe.unet)
|
||||
unet_lora_config = pipe.unet.peft_config["default"]
|
||||
|
||||
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
|
||||
text_encoder_lora_config = pipe.text_encoder.peft_config["default"]
|
||||
|
||||
if self.has_two_text_encoders:
|
||||
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
|
||||
text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"]
|
||||
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname,
|
||||
unet_lora_layers=unet_state_dict,
|
||||
text_encoder_lora_layers=text_encoder_state_dict,
|
||||
text_encoder_2_lora_layers=text_encoder_2_state_dict,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_lora_config,
|
||||
text_encoder_2_lora_config=text_encoder_2_lora_config,
|
||||
)
|
||||
else:
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname,
|
||||
unet_lora_layers=unet_state_dict,
|
||||
text_encoder_lora_layers=text_encoder_state_dict,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_lora_config,
|
||||
)
|
||||
loaded_pipe = self.pipeline_class(**components)
|
||||
loaded_pipe.load_lora_weights(tmpdirname)
|
||||
|
||||
# Inference works?
|
||||
_ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
assert (
|
||||
loaded_pipe.unet.peft_config["default"].lora_alpha == lora_alpha
|
||||
), "LoRA alpha not correctly loaded for UNet."
|
||||
assert (
|
||||
loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha
|
||||
), "LoRA alpha not correctly loaded for text encoder."
|
||||
if self.has_two_text_encoders:
|
||||
assert (
|
||||
loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha
|
||||
), "LoRA alpha not correctly loaded for text encoder 2."
|
||||
|
||||
def test_simple_inference_with_text_unet_lora_unfused(self):
|
||||
"""
|
||||
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
|
||||
|
||||
Reference in New Issue
Block a user