Compare commits

...

19 Commits

Author SHA1 Message Date
Sayak Paul
9d411fd5c2 Merge branch 'main' into feat-config-metadata 2024-08-21 15:47:28 +05:30
sayakpaul
178a4596b2 add: comment. 2024-08-12 18:40:23 +05:30
sayakpaul
e7808e4c35 fix 2024-08-12 16:25:30 +05:30
Sayak Paul
d79c0d5970 Merge branch 'main' into feat-config-metadata 2024-08-12 15:25:59 +05:30
sayakpaul
f7d30de31a alpha_pattern and rank+pattern 2024-08-12 15:25:43 +05:30
sayakpaul
712a110e76 style 2024-08-10 23:18:29 +05:30
sayakpaul
1852c3fa3b style 2024-08-10 23:14:21 +05:30
sayakpaul
6fb5987e40 utilize fix-copies better. 2024-08-10 23:06:21 +05:30
sayakpaul
91cfffc54b check, 2024-08-10 18:55:24 +05:30
sayakpaul
632bf78c04 fix-copies 2024-08-10 18:50:39 +05:30
sayakpaul
48768b66b6 style 2024-08-10 08:34:40 +05:30
Sayak Paul
225634e8d1 Merge branch 'main' into feat-config-metadata 2024-08-10 08:14:03 +05:30
sayakpaul
6bbf629349 style 2024-08-10 08:13:28 +05:30
sayakpaul
167557b9bf fix-copues 2024-08-10 08:13:08 +05:30
sayakpaul
733e1d9259 documentation 2024-08-10 08:10:29 +05:30
sayakpaul
79aff1d82b fix lora_alpha 2024-08-09 20:15:47 +05:30
Sayak Paul
fb9e86d806 Apply suggestions from code review
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
2024-08-09 20:04:01 +05:30
sayakpaul
2b9f77e2d3 fix flux test 2024-08-09 19:58:29 +05:30
sayakpaul
1819000151 feat: add non-breaking support to serialize metadata in loras. 2024-08-09 18:41:56 +05:30
6 changed files with 368 additions and 55 deletions

View File

@@ -14,6 +14,7 @@
import copy
import inspect
import json
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
@@ -26,6 +27,7 @@ from huggingface_hub.constants import HF_HUB_OFFLINE
from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
SAFETENSORS_FILE_EXTENSION,
USE_PEFT_BACKEND,
_get_model_file,
delete_adapter_layers,
@@ -44,6 +46,7 @@ if is_transformers_available():
from transformers import PreTrainedModel
if is_peft_available():
from peft import LoraConfig
from peft.tuners.tuners_utils import BaseTunerLayer
if is_accelerate_available():
@@ -252,6 +255,7 @@ class LoraBaseMixin:
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
model_file = None
metadata = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
@@ -280,6 +284,8 @@ class LoraBaseMixin:
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
metadata = f.metadata()
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
@@ -305,10 +311,14 @@ class LoraBaseMixin:
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
file_extension = os.path.basename(model_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
metadata = f.metadata()
else:
state_dict = pretrained_model_name_or_path_or_dict
return state_dict
return state_dict, metadata
@classmethod
def _best_guess_weight_name(
@@ -709,6 +719,20 @@ class LoraBaseMixin:
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
@staticmethod
def pack_metadata(config, prefix):
local_metadata = {}
if config is not None:
if isinstance(config, LoraConfig):
config = config.to_dict()
for key, value in config.items():
if isinstance(value, set):
config[key] = list(value)
config_as_string = json.dumps(config, indent=2, sort_keys=True)
local_metadata[prefix] = config_as_string
return local_metadata
@staticmethod
def write_lora_layers(
state_dict: Dict[str, torch.Tensor],
@@ -717,9 +741,13 @@ class LoraBaseMixin:
weight_name: str,
save_function: Callable,
safe_serialization: bool,
metadata=None,
):
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0:
raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.")
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@@ -727,8 +755,12 @@ class LoraBaseMixin:
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
def save_function(weights, filename, metadata):
if metadata is None:
metadata = {"format": "pt"}
elif len(metadata) > 0:
metadata.update({"format": "pt"})
return safetensors.torch.save_file(weights, filename, metadata=metadata)
else:
save_function = torch.save
@@ -742,7 +774,10 @@ class LoraBaseMixin:
weight_name = LORA_WEIGHT_NAME
save_path = Path(save_directory, weight_name).as_posix()
save_function(state_dict, save_path)
if save_function != torch.save:
save_function(state_dict, save_path, metadata)
else:
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
@property

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import Callable, Dict, List, Optional, Union
@@ -92,7 +93,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
state_dict, network_alphas, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True
)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
@@ -104,6 +107,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
adapter_name=adapter_name,
_pipeline=self,
config=metadata,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -113,6 +117,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
else self.text_encoder,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
config=metadata,
_pipeline=self,
)
@@ -168,6 +173,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
The subfolder location of a model file within a larger model repository on the Hub or locally.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
return_metadata (`bool`):
If state dict metadata should be returned. Is only supported when the state dict has a safetensors
extension.
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
@@ -181,6 +189,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
weight_name = kwargs.pop("weight_name", None)
unet_config = kwargs.pop("unet_config", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_metadata = kwargs.pop("return_metadata", False)
allow_pickle = False
if use_safetensors is None:
@@ -192,7 +201,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
"framework": "pytorch",
}
state_dict = cls._fetch_state_dict(
state_dict, metadata = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -224,10 +233,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
return state_dict, network_alphas
if return_metadata:
return state_dict, network_alphas, metadata
else:
return state_dict, network_alphas
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, config=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -245,6 +257,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -258,7 +271,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
config=config,
_pipeline=_pipeline,
)
@classmethod
@@ -270,6 +287,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
prefix=None,
lora_scale=1.0,
adapter_name=None,
config=None,
_pipeline=None,
):
"""
@@ -291,6 +309,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -303,6 +322,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
if config is not None and len(config) > 0:
config = json.loads(config[prefix])
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
@@ -341,7 +363,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
@@ -385,6 +409,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
unet_lora_config: dict = None,
text_encoder_lora_config: dict = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -401,6 +427,10 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
unet_lora_config (`dict`, *optional*):
LoRA configuration used to train the `unet_lora_layers`.
text_encoder_lora_config (`dict`, *optional*):
LoRA configuration used to train the `text_encoder_lora_layers`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -413,19 +443,27 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
metadata = {}
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
if unet_lora_layers:
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
if unet_lora_config:
unet_metadata = cls.pack_metadata(unet_lora_config, cls.unet_name)
metadata.update(unet_metadata)
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
if text_encoder_lora_config:
te_metadata = cls.pack_metadata(text_encoder_lora_config, cls.text_encoder_name)
metadata.update(te_metadata)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
metadata=metadata,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
@@ -489,10 +527,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
@@ -550,9 +584,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(
state_dict, network_alphas, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
return_metadata=True,
**kwargs,
)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
@@ -560,7 +595,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(
state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
state_dict,
network_alphas=network_alphas,
unet=self.unet,
adapter_name=adapter_name,
config=metadata,
_pipeline=self,
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
@@ -571,6 +611,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
config=metadata,
_pipeline=self,
)
@@ -583,6 +624,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
prefix="text_encoder_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
config=metadata,
_pipeline=self,
)
@@ -639,6 +681,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
The subfolder location of a model file within a larger model repository on the Hub or locally.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
return_metadata (`bool`):
If state dict metadata should be returned. Is only supported when the state dict has a safetensors
extension.
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
@@ -652,6 +697,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
weight_name = kwargs.pop("weight_name", None)
unet_config = kwargs.pop("unet_config", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_metadata = kwargs.pop("return_metadata", False)
allow_pickle = False
if use_safetensors is None:
@@ -663,7 +709,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
"framework": "pytorch",
}
state_dict = cls._fetch_state_dict(
state_dict, metadata = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -695,11 +741,14 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
return state_dict, network_alphas
if return_metadata:
return state_dict, network_alphas, metadata
else:
return state_dict, network_alphas
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, config=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -717,6 +766,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -730,7 +780,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
config=config,
_pipeline=_pipeline,
)
@classmethod
@@ -743,6 +797,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
prefix=None,
lora_scale=1.0,
adapter_name=None,
config=None,
_pipeline=None,
):
"""
@@ -764,6 +819,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -776,6 +832,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
if config is not None and len(config) > 0:
config = json.loads(config[prefix])
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
@@ -814,7 +873,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
@@ -859,6 +920,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
unet_lora_config: dict = None,
text_encoder_lora_config: dict = None,
text_encoder_2_lora_config: dict = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -878,6 +942,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
unet_lora_config (`dict`, *optional*):
LoRA configuration used to train the `unet_lora_layers`.
text_encoder_lora_config (`dict`, *optional*):
LoRA configuration used to train the `text_encoder_lora_layers`.
text_encoder_2_lora_config (`dict`, *optional*):
LoRA configuration used to train the `text_encoder_2_lora_layers`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -890,6 +960,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
metadata = {}
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
@@ -898,15 +969,25 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
if unet_lora_layers:
state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
if unet_lora_config is not None:
unet_metadata = cls.pack_metadata(unet_lora_config, "unet")
metadata.update(unet_metadata)
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
if text_encoder_lora_config is not None:
te_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder")
metadata.update(te_metadata)
if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
if text_encoder_2_lora_config is not None:
te2_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder_2")
metadata.update(te2_metadata)
cls.write_lora_layers(
state_dict=state_dict,
metadata=metadata,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
@@ -970,10 +1051,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
@@ -1041,6 +1118,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_metadata (`bool`, *optional*):
If state dict metadata should be returned. Is only supported when the state dict has a safetensors
extension.
"""
# Load the main state dict first which has the LoRA layers for either of
@@ -1054,6 +1134,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_metadata = kwargs.pop("return_metadata", False)
allow_pickle = False
if use_safetensors is None:
@@ -1065,7 +1146,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
"framework": "pytorch",
}
state_dict = cls._fetch_state_dict(
state_dict, metadata = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -1080,13 +1161,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
allow_pickle=allow_pickle,
)
return state_dict
# Otherwise, this would be a breaking change.
if return_metadata:
return state_dict, metadata
else:
return state_dict
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
@@ -1114,7 +1199,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
state_dict, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True
)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
@@ -1124,6 +1211,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
config=metadata,
_pipeline=self,
)
@@ -1136,6 +1224,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
config=metadata,
_pipeline=self,
)
@@ -1148,11 +1237,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
prefix="text_encoder_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
config=metadata,
_pipeline=self,
)
@classmethod
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, config=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1166,6 +1256,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
config (`dict`): Configuration that was used to train this LoRA.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
@@ -1192,7 +1283,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
if config is not None and isinstance(config, dict) and len(config) > 0:
config = json.loads(config[cls.transformer_name])
lora_config_kwargs = get_peft_kwargs(
rank, network_alpha_dict=None, config=config, peft_state_dict=state_dict
)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
@@ -1239,6 +1334,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
prefix=None,
lora_scale=1.0,
adapter_name=None,
config=None,
_pipeline=None,
):
"""
@@ -1260,6 +1356,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1272,6 +1369,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
if config is not None and len(config) > 0:
config = json.loads(config[prefix])
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
@@ -1310,7 +1410,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
@@ -1349,12 +1451,16 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
# Unsafe code />
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
transformer_lora_config: dict = None,
text_encoder_lora_config: dict = None,
text_encoder_2_lora_config: dict = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -1374,6 +1480,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
transformer_lora_config (`dict`, *optional*):
LoRA configuration used to train the `transformer_lora_layers`.
text_encoder_lora_config (`dict`, *optional*):
LoRA configuration used to train the `text_encoder_lora_layers`.
text_encoder_2_lora_config (`dict`, *optional*):
LoRA configuration used to train the `text_encoder_2_lora_layers`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -1386,24 +1498,34 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
metadata = {}
if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
"You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
)
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, "transformer"))
if transformer_lora_config is not None:
transformer_metadata = cls.pack_metadata(transformer_lora_config, "transformer")
metadata.update(transformer_metadata)
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
if text_encoder_lora_config is not None:
te_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder")
metadata.update(te_metadata)
if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
if text_encoder_2_lora_config is not None:
te2_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder_2")
metadata.update(te2_metadata)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
metadata=metadata,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
@@ -1411,6 +1533,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
@@ -1454,6 +1577,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
r"""
Reverses the effect of
@@ -1467,10 +1591,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
@@ -1538,6 +1658,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_metadata (`bool`, *optional*):
If state dict metadata should be returned. Is only supported when the state dict has a safetensors
extension.
"""
# Load the main state dict first which has the LoRA layers for either of
@@ -1551,6 +1674,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_metadata = kwargs.pop("return_metadata", False)
allow_pickle = False
if use_safetensors is None:
@@ -1562,7 +1686,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
"framework": "pytorch",
}
state_dict = cls._fetch_state_dict(
state_dict, metadata = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -1577,7 +1701,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
allow_pickle=allow_pickle,
)
return state_dict
# Otherwise, this would be a breaking change.
if return_metadata:
return state_dict, metadata
else:
return state_dict
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
@@ -1611,7 +1739,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
state_dict, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True
)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
@@ -1621,6 +1751,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
config=metadata,
_pipeline=self,
)
@@ -1633,12 +1764,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
config=metadata,
_pipeline=self,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, config=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1652,6 +1784,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
config (`dict`): Configuration that was used to train this LoRA.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
@@ -1678,7 +1811,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
if config is not None and isinstance(config, dict) and len(config) > 0:
config = json.loads(config[cls.transformer_name])
lora_config_kwargs = get_peft_kwargs(
rank, network_alpha_dict=None, config=config, peft_state_dict=state_dict
)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
@@ -1725,6 +1862,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
prefix=None,
lora_scale=1.0,
adapter_name=None,
config=None,
_pipeline=None,
):
"""
@@ -1746,6 +1884,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1758,6 +1897,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
if config is not None and len(config) > 0:
config = json.loads(config[prefix])
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
@@ -1796,7 +1938,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
@@ -1841,6 +1985,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
transformer_lora_config: dict = None,
text_encoder_lora_config: dict = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -1857,6 +2003,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
transformer_lora_config (`dict`, *optional*):
LoRA configuration used to train the `transformer_lora_layers`.
text_encoder_lora_config (`dict`, *optional*):
LoRA configuration used to train the `text_encoder_lora_layers`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -1869,19 +2019,27 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
metadata = {}
if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_config:
transformer_metadata = cls.pack_metadata(transformer_lora_config, cls.transformer_name)
metadata.update(transformer_metadata)
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
if text_encoder_lora_config:
te_metadata = cls.pack_metadata(text_encoder_lora_config, cls.text_encoder_name)
metadata.update(te_metadata)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
metadata=metadata,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
@@ -1933,6 +2091,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
@@ -2051,6 +2210,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
prefix=None,
lora_scale=1.0,
adapter_name=None,
config=None,
_pipeline=None,
):
"""
@@ -2072,6 +2232,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -2084,6 +2245,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
if config is not None and len(config) > 0:
config = json.loads(config[prefix])
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
@@ -2122,7 +2286,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import defaultdict
from contextlib import nullcontext
@@ -115,6 +116,7 @@ class UNet2DConditionLoadersMixin:
`default_{i}` where i is the total number of adapters being loaded.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
config: (`dict`, *optional*)
Example:
@@ -143,6 +145,7 @@ class UNet2DConditionLoadersMixin:
_pipeline = kwargs.pop("_pipeline", None)
network_alphas = kwargs.pop("network_alphas", None)
allow_pickle = False
config = kwargs.pop("config", None)
if use_safetensors is None:
use_safetensors = True
@@ -208,6 +211,7 @@ class UNet2DConditionLoadersMixin:
unet_identifier_key=self.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
config=config,
_pipeline=_pipeline,
)
else:
@@ -268,7 +272,7 @@ class UNet2DConditionLoadersMixin:
return attn_processors
def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, config=None):
# This method does the following things:
# 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
# format. For legacy format no filtering is applied.
@@ -316,7 +320,10 @@ class UNet2DConditionLoadersMixin:
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
if config is not None and isinstance(config, dict) and len(config) > 0:
config = json.loads(config["unet"])
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):

View File

@@ -21,9 +21,12 @@ from typing import Optional
from packaging import version
from . import logging
from .import_utils import is_peft_available, is_torch_available
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
@@ -147,11 +150,29 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
# Try to retrieve config.
alpha_retrieved = False
if config is not None:
lora_alpha = config["lora_alpha"] if "lora_alpha" in config else lora_alpha
alpha_retrieved = True
# We simply ignore the `alpha_pattern` and `rank_pattern` if they are found
# in the `config`. This is because:
# 1. We determine `rank_pattern` from the `rank_dict`.
# 2. When `network_alpha_dict` is passed that means the underlying checkpoint
# is a non-diffusers checkpoint.
# More details: https://github.com/huggingface/diffusers/pull/9143#discussion_r1711491175
if config.get("alpha_pattern", None) is not None:
logger.warning("`alpha_pattern` found in the LoRA config. This will be ignored.")
if config.get("rank_pattern", None) is not None:
logger.warning("`rank_pattern` found in the LoRA config. This will be ignored.")
if len(set(rank_dict.values())) > 1:
# get the rank occuring the most number of times
r = collections.Counter(rank_dict.values()).most_common()[0][0]
@@ -160,7 +181,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
if not alpha_retrieved and network_alpha_dict is not None and len(network_alpha_dict) > 0:
if len(set(network_alpha_dict.values())) > 1:
# get the alpha occuring the most number of times
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

View File

@@ -19,7 +19,7 @@ import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, torch_device
sys.path.append(".")
@@ -28,6 +28,7 @@ from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()

View File

@@ -87,7 +87,7 @@ class PeftLoraLoaderMixinTests:
transformer_kwargs = None
vae_kwargs = None
def get_dummy_components(self, scheduler_cls=None, use_dora=False):
def get_dummy_components(self, scheduler_cls=None, use_dora=False, alpha=None):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders:
@@ -95,6 +95,7 @@ class PeftLoraLoaderMixinTests:
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
rank = 4
alpha = rank if alpha is None else alpha
torch.manual_seed(0)
if self.unet_kwargs is not None:
@@ -120,7 +121,7 @@ class PeftLoraLoaderMixinTests:
text_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
lora_alpha=alpha,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
init_lora_weights=False,
use_dora=use_dora,
@@ -128,7 +129,7 @@ class PeftLoraLoaderMixinTests:
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
lora_alpha=alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
@@ -1752,3 +1753,85 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs()
_ = pipe(**inputs).images
def test_if_lora_alpha_is_correctly_parsed(self):
lora_alpha = 8
scheduler_class = FlowMatchEulerDiscreteScheduler if self.uses_flow_matching else DDIMScheduler
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
alpha=lora_alpha, scheduler_cls=scheduler_class
)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
# Inference works?
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname:
denoiser = pipe.unet if self.unet_kwargs else pipe.transformer
denoiser_state_dict = get_peft_model_state_dict(denoiser)
denoiser_lora_config = denoiser.peft_config["default"]
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
text_encoder_lora_config = pipe.text_encoder.peft_config["default"]
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"]
saving_kwargs = {
"save_directory": tmpdirname,
"text_encoder_lora_layers": text_encoder_state_dict,
"text_encoder_lora_config": text_encoder_lora_config,
}
if self.unet_kwargs is not None:
saving_kwargs.update(
{"unet_lora_layers": denoiser_state_dict, "unet_lora_config": denoiser_lora_config}
)
else:
saving_kwargs.update(
{"transformer_lora_layers": denoiser_state_dict, "transformer_lora_config": denoiser_lora_config}
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
saving_kwargs.update(
{
"text_encoder_2_lora_layers": text_encoder_2_state_dict,
"text_encoder_2_lora_config": text_encoder_2_lora_config,
}
)
self.pipeline_class.save_lora_weights(**saving_kwargs)
loaded_pipe = self.pipeline_class(**components)
loaded_pipe.load_lora_weights(tmpdirname)
# Inference works?
_ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images
denoiser_loaded = pipe.unet if self.unet_kwargs is not None else pipe.transformer
assert (
denoiser_loaded.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for UNet."
assert (
loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder."
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
assert (
loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder 2."