Compare commits


19 Commits

Author      SHA1        Message                                          Date
Sayak Paul  9d411fd5c2  Merge branch 'main' into feat-config-metadata    2024-08-21 15:47:28 +05:30
sayakpaul   178a4596b2  add: comment.                                    2024-08-12 18:40:23 +05:30
sayakpaul   e7808e4c35  fix                                              2024-08-12 16:25:30 +05:30
Sayak Paul  d79c0d5970  Merge branch 'main' into feat-config-metadata    2024-08-12 15:25:59 +05:30
sayakpaul   f7d30de31a  alpha_pattern and rank+pattern                   2024-08-12 15:25:43 +05:30
sayakpaul   712a110e76  style                                            2024-08-10 23:18:29 +05:30
sayakpaul   1852c3fa3b  style                                            2024-08-10 23:14:21 +05:30
sayakpaul   6fb5987e40  utilize fix-copies better.                       2024-08-10 23:06:21 +05:30
sayakpaul   91cfffc54b  check,                                           2024-08-10 18:55:24 +05:30
sayakpaul   632bf78c04  fix-copies                                       2024-08-10 18:50:39 +05:30
sayakpaul   48768b66b6  style                                            2024-08-10 08:34:40 +05:30
Sayak Paul  225634e8d1  Merge branch 'main' into feat-config-metadata    2024-08-10 08:14:03 +05:30
sayakpaul   6bbf629349  style                                            2024-08-10 08:13:28 +05:30
sayakpaul   167557b9bf  fix-copues                                       2024-08-10 08:13:08 +05:30
sayakpaul   733e1d9259  documentation                                    2024-08-10 08:10:29 +05:30
sayakpaul   79aff1d82b  fix lora_alpha                                   2024-08-09 20:15:47 +05:30
Sayak Paul  fb9e86d806  Apply suggestions from code review (Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>)  2024-08-09 20:04:01 +05:30
sayakpaul   2b9f77e2d3  fix flux test                                    2024-08-09 19:58:29 +05:30
sayakpaul   1819000151  feat: add non-breaking support to serialize metadata in loras.  2024-08-09 18:41:56 +05:30
6 changed files with 368 additions and 55 deletions

View File

@@ -14,6 +14,7 @@
 import copy
 import inspect
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -26,6 +27,7 @@ from huggingface_hub.constants import HF_HUB_OFFLINE
 from ..models.modeling_utils import ModelMixin, load_state_dict
 from ..utils import (
+    SAFETENSORS_FILE_EXTENSION,
     USE_PEFT_BACKEND,
     _get_model_file,
     delete_adapter_layers,
@@ -44,6 +46,7 @@ if is_transformers_available():
     from transformers import PreTrainedModel

 if is_peft_available():
+    from peft import LoraConfig
     from peft.tuners.tuners_utils import BaseTunerLayer

 if is_accelerate_available():
@@ -252,6 +255,7 @@ class LoraBaseMixin:
         from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

         model_file = None
+        metadata = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
             if (use_safetensors and weight_name is None) or (
@@ -280,6 +284,8 @@ class LoraBaseMixin:
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
+                    metadata = f.metadata()
             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
                     raise e
@@ -305,10 +311,14 @@ class LoraBaseMixin:
                 user_agent=user_agent,
             )
             state_dict = load_state_dict(model_file)
+            file_extension = os.path.basename(model_file).split(".")[-1]
+            if file_extension == SAFETENSORS_FILE_EXTENSION:
+                with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
+                    metadata = f.metadata()
         else:
             state_dict = pretrained_model_name_or_path_or_dict

-        return state_dict
+        return state_dict, metadata

     @classmethod
     def _best_guess_weight_name(
@@ -709,6 +719,20 @@ class LoraBaseMixin:
         layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
         return layers_state_dict

+    @staticmethod
+    def pack_metadata(config, prefix):
+        local_metadata = {}
+        if config is not None:
+            if isinstance(config, LoraConfig):
+                config = config.to_dict()
+            for key, value in config.items():
+                if isinstance(value, set):
+                    config[key] = list(value)
+            config_as_string = json.dumps(config, indent=2, sort_keys=True)
+            local_metadata[prefix] = config_as_string
+
+        return local_metadata
+
     @staticmethod
     def write_lora_layers(
         state_dict: Dict[str, torch.Tensor],
@@ -717,9 +741,13 @@ class LoraBaseMixin:
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        metadata=None,
     ):
         from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

+        if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0:
+            raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.")
+
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
@@ -727,8 +755,12 @@ class LoraBaseMixin:
         if save_function is None:
             if safe_serialization:

-                def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                def save_function(weights, filename, metadata):
+                    if metadata is None:
+                        metadata = {"format": "pt"}
+                    elif len(metadata) > 0:
+                        metadata.update({"format": "pt"})
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)

             else:
                 save_function = torch.save
@@ -742,6 +774,9 @@ class LoraBaseMixin:
                 weight_name = LORA_WEIGHT_NAME

         save_path = Path(save_directory, weight_name).as_posix()
-        save_function(state_dict, save_path)
+        if save_function != torch.save:
+            save_function(state_dict, save_path, metadata)
+        else:
+            save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")
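
The changes in this file amount to storing a JSON-encoded LoRA config in the safetensors header and reading it back on load. A minimal standalone sketch of that round trip (independent of diffusers; the file name and tensor key below are placeholders for illustration):

import json

import safetensors.torch
import torch
from peft import LoraConfig

# Build the metadata the way pack_metadata does: JSON-encode the config dict,
# converting sets (e.g. target_modules) to lists first; safetensors metadata is str -> str.
config = LoraConfig(r=4, lora_alpha=8, target_modules=["to_k", "to_q"]).to_dict()
for key, value in config.items():
    if isinstance(value, set):
        config[key] = list(value)
metadata = {"unet": json.dumps(config, indent=2, sort_keys=True), "format": "pt"}

# Write a dummy LoRA state dict together with the metadata.
state_dict = {"unet.dummy.lora_A.weight": torch.zeros(4, 8)}
safetensors.torch.save_file(state_dict, "lora_with_metadata.safetensors", metadata=metadata)

# Read it back the way _fetch_state_dict does.
with safetensors.safe_open("lora_with_metadata.safetensors", framework="pt", device="cpu") as f:
    restored = json.loads(f.metadata()["unet"])
print(restored["r"], restored["lora_alpha"])  # 4 8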

View File

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import os
 from typing import Callable, Dict, List, Optional, Union
@@ -92,7 +93,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, network_alphas, metadata = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True
+        )

         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -104,6 +107,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
             adapter_name=adapter_name,
             _pipeline=self,
+            config=metadata,
         )
         self.load_lora_into_text_encoder(
             state_dict,
@@ -113,6 +117,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             else self.text_encoder,
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
+            config=metadata,
             _pipeline=self,
         )
@@ -168,6 +173,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            return_metadata (`bool`):
+                If state dict metadata should be returned. Is only supported when the state dict has a safetensors
+                extension.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -181,6 +189,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         weight_name = kwargs.pop("weight_name", None)
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_metadata = kwargs.pop("return_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
@@ -192,7 +201,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }

-        state_dict = cls._fetch_state_dict(
+        state_dict, metadata = cls._fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -224,10 +233,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

-        return state_dict, network_alphas
+        if return_metadata:
+            return state_dict, network_alphas, metadata
+        else:
+            return state_dict, network_alphas

     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, config=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -245,6 +257,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -258,7 +271,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
             unet.load_attn_procs(
-                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                config=config,
+                _pipeline=_pipeline,
             )

     @classmethod
@@ -270,6 +287,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        config=None,
         _pipeline=None,
     ):
         """
@@ -291,6 +309,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -303,6 +322,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         keys = list(state_dict.keys())
         prefix = cls.text_encoder_name if prefix is None else prefix

+        if config is not None and len(config) > 0:
+            config = json.loads(config[prefix])
+
         # Safe prefix to check with.
         if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
@@ -341,7 +363,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                         k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                     }

-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+                lora_config_kwargs = get_peft_kwargs(
+                    rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
+                )
                 if "use_dora" in lora_config_kwargs:
                     if lora_config_kwargs["use_dora"]:
                         if is_peft_version("<", "0.9.0"):
@@ -385,6 +409,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         save_directory: Union[str, os.PathLike],
         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+        unet_lora_config: dict = None,
+        text_encoder_lora_config: dict = None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
@@ -401,6 +427,10 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                 encoder LoRA state dict because it comes from 🤗 Transformers.
+            unet_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `unet_lora_layers`.
+            text_encoder_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `text_encoder_lora_layers`.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -413,19 +443,27 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         state_dict = {}
+        metadata = {}

         if not (unet_lora_layers or text_encoder_lora_layers):
             raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")

         if unet_lora_layers:
             state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+            if unet_lora_config:
+                unet_metadata = cls.pack_metadata(unet_lora_config, cls.unet_name)
+                metadata.update(unet_metadata)

         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+            if text_encoder_lora_config:
+                te_metadata = cls.pack_metadata(text_encoder_lora_config, cls.text_encoder_name)
+                metadata.update(te_metadata)

         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
+            metadata=metadata,
             save_directory=save_directory,
             is_main_process=is_main_process,
             weight_name=weight_name,
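
Taken together with the write_lora_layers changes in the first file, the extended save path could be exercised from user code roughly as follows. This is a hedged sketch against this branch: the tensor keys and the output directory are placeholders, and in practice the state dicts come from a real training run (for example via peft's get_peft_model_state_dict); StableDiffusionPipeline is assumed here only as a carrier of this mixin.

import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig

# Stand-in LoRA tensors; in practice these come from the trained UNet/text encoder.
unet_lora_layers = {
    "dummy_block.lora_A.weight": torch.zeros(4, 320),
    "dummy_block.lora_B.weight": torch.zeros(320, 4),
}
unet_lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_k", "to_q", "to_v", "to_out.0"])

# save_lora_weights is a classmethod, so no pipeline instance is needed just to serialize.
StableDiffusionPipeline.save_lora_weights(
    save_directory="sd-lora-with-metadata",   # placeholder output directory
    unet_lora_layers=unet_lora_layers,
    unet_lora_config=unet_lora_config,        # stored as JSON under the "unet" key in the safetensors header
)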
@@ -489,10 +527,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
-            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
         super().unfuse_lora(components=components)
@@ -550,9 +584,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(
+        state_dict, network_alphas, metadata = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
+            return_metadata=True,
             **kwargs,
         )
         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
@@ -560,7 +595,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            adapter_name=adapter_name,
+            config=metadata,
+            _pipeline=self,
         )
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
@@ -571,6 +611,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
+                config=metadata,
                 _pipeline=self,
             )
@@ -583,6 +624,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
+                config=metadata,
                 _pipeline=self,
             )
@@ -639,6 +681,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            return_metadata (`bool`):
+                If state dict metadata should be returned. Is only supported when the state dict has a safetensors
+                extension.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -652,6 +697,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         weight_name = kwargs.pop("weight_name", None)
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_metadata = kwargs.pop("return_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
@@ -663,7 +709,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }

-        state_dict = cls._fetch_state_dict(
+        state_dict, metadata = cls._fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -695,11 +741,14 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

-        return state_dict, network_alphas
+        if return_metadata:
+            return state_dict, network_alphas, metadata
+        else:
+            return state_dict, network_alphas

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, config=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -717,6 +766,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -730,7 +780,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
             unet.load_attn_procs(
-                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                config=config,
+                _pipeline=_pipeline,
             )

     @classmethod
@@ -743,6 +797,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        config=None,
         _pipeline=None,
     ):
         """
@@ -764,6 +819,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -776,6 +832,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         keys = list(state_dict.keys())
         prefix = cls.text_encoder_name if prefix is None else prefix

+        if config is not None and len(config) > 0:
+            config = json.loads(config[prefix])
+
         # Safe prefix to check with.
         if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
@@ -814,7 +873,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                         k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                     }

-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+                lora_config_kwargs = get_peft_kwargs(
+                    rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
+                )
                 if "use_dora" in lora_config_kwargs:
                     if lora_config_kwargs["use_dora"]:
                         if is_peft_version("<", "0.9.0"):
@@ -859,6 +920,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        unet_lora_config: dict = None,
+        text_encoder_lora_config: dict = None,
+        text_encoder_2_lora_config: dict = None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
@@ -878,6 +942,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
                 encoder LoRA state dict because it comes from 🤗 Transformers.
+            unet_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `unet_lora_layers`.
+            text_encoder_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `text_encoder_lora_layers`.
+            text_encoder_2_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `text_encoder_2_lora_layers`.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -890,6 +960,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         state_dict = {}
+        metadata = {}

         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
@@ -898,15 +969,25 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         if unet_lora_layers:
             state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
+            if unet_lora_config is not None:
+                unet_metadata = cls.pack_metadata(unet_lora_config, "unet")
+                metadata.update(unet_metadata)

         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+            if text_encoder_lora_config is not None:
+                te_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder")
+                metadata.update(te_metadata)

         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+            if text_encoder_2_lora_config is not None:
+                te2_metadata = cls.pack_metadata(text_encoder_2_lora_config, "text_encoder_2")
+                metadata.update(te2_metadata)

         cls.write_lora_layers(
             state_dict=state_dict,
+            metadata=metadata,
             save_directory=save_directory,
             is_main_process=is_main_process,
             weight_name=weight_name,
@@ -970,10 +1051,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
-            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
         super().unfuse_lora(components=components)
@@ -1041,6 +1118,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_metadata (`bool`, *optional*):
+                If state dict metadata should be returned. Is only supported when the state dict has a safetensors
+                extension.

         """
         # Load the main state dict first which has the LoRA layers for either of
@@ -1054,6 +1134,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_metadata = kwargs.pop("return_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
@@ -1065,7 +1146,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }

-        state_dict = cls._fetch_state_dict(
+        state_dict, metadata = cls._fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -1080,13 +1161,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             allow_pickle=allow_pickle,
         )

-        return state_dict
+        # Otherwise, this would be a breaking change.
+        if return_metadata:
+            return state_dict, metadata
+        else:
+            return state_dict

     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
     ):
         """
-        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
         `self.text_encoder`.

         All kwargs are forwarded to `self.lora_state_dict`.
@@ -1114,7 +1199,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, metadata = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True
+        )

         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -1124,6 +1211,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            config=metadata,
             _pipeline=self,
         )
@@ -1136,6 +1224,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
+                config=metadata,
                 _pipeline=self,
             )
@@ -1148,11 +1237,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
+                config=metadata,
                 _pipeline=self,
             )

     @classmethod
-    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, config=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1166,6 +1256,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`): Configuration that was used to train this LoRA.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
@@ -1192,7 +1283,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 if "lora_B" in key:
                     rank[key] = val.shape[1]

-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+            if config is not None and isinstance(config, dict) and len(config) > 0:
+                config = json.loads(config[cls.transformer_name])
+
+            lora_config_kwargs = get_peft_kwargs(
+                rank, network_alpha_dict=None, config=config, peft_state_dict=state_dict
+            )
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
                     raise ValueError(
@@ -1239,6 +1334,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        config=None,
         _pipeline=None,
     ):
         """
@@ -1260,6 +1356,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -1272,6 +1369,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         keys = list(state_dict.keys())
         prefix = cls.text_encoder_name if prefix is None else prefix

+        if config is not None and len(config) > 0:
+            config = json.loads(config[prefix])
+
         # Safe prefix to check with.
         if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
@@ -1310,7 +1410,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                         k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                     }

-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+                lora_config_kwargs = get_peft_kwargs(
+                    rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
+                )
                 if "use_dora" in lora_config_kwargs:
                     if lora_config_kwargs["use_dora"]:
                         if is_peft_version("<", "0.9.0"):
@@ -1349,12 +1451,16 @@
         # Unsafe code />

     @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
     def save_lora_weights(
         cls,
         save_directory: Union[str, os.PathLike],
-        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        transformer_lora_config: dict = None,
+        text_encoder_lora_config: dict = None,
+        text_encoder_2_lora_config: dict = None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
@@ -1374,6 +1480,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
                 encoder LoRA state dict because it comes from 🤗 Transformers.
+            transformer_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `transformer_lora_layers`.
+            text_encoder_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `text_encoder_lora_layers`.
+            text_encoder_2_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `text_encoder_2_lora_layers`.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -1386,24 +1498,34 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         state_dict = {}
+        metadata = {}

         if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
-                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
+                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
             )

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            state_dict.update(cls.pack_weights(transformer_lora_layers, "transformer"))
+            if transformer_lora_config is not None:
+                transformer_metadata = cls.pack_metadata(transformer_lora_config, "transformer")
+                metadata.update(transformer_metadata)

         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+            if text_encoder_lora_config is not None:
+                te_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder")
+                metadata.update(te_metadata)

         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+            if text_encoder_2_lora_config is not None:
+                te2_metadata = cls.pack_metadata(text_encoder_2_lora_config, "text_encoder_2")
+                metadata.update(te2_metadata)

-        # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
+            metadata=metadata,
             save_directory=save_directory,
             is_main_process=is_main_process,
             weight_name=weight_name,
@@ -1411,6 +1533,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
         components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
@@ -1454,6 +1577,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
     def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
         r"""
         Reverses the effect of
@@ -1467,10 +1591,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
-            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
         super().unfuse_lora(components=components)
@@ -1538,6 +1658,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
allowed by Git. allowed by Git.
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally. The subfolder location of a model file within a larger model repository on the Hub or locally.
return_metadata (`bool`, *optional*):
If state dict metadata should be returned. Is only supported when the state dict has a safetensors
extension.
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
@@ -1551,6 +1674,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
return_metadata = kwargs.pop("return_metadata", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
@@ -1562,7 +1686,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
"framework": "pytorch", "framework": "pytorch",
} }
state_dict = cls._fetch_state_dict( state_dict, metadata = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
@@ -1577,6 +1701,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
allow_pickle=allow_pickle, allow_pickle=allow_pickle,
) )
# Otherwise, this would be a breaking change.
if return_metadata:
return state_dict, metadata
else:
return state_dict return state_dict
def load_lora_weights( def load_lora_weights(
@@ -1611,7 +1739,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) state_dict, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True
)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format: if not is_correct_format:
@@ -1621,6 +1751,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
config=metadata,
_pipeline=self, _pipeline=self,
) )
@@ -1633,12 +1764,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
prefix="text_encoder", prefix="text_encoder",
lora_scale=self.lora_scale, lora_scale=self.lora_scale,
adapter_name=adapter_name, adapter_name=adapter_name,
config=metadata,
_pipeline=self, _pipeline=self,
) )
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, config=None, _pipeline=None):
""" """
This will load the LoRA layers specified in `state_dict` into `transformer`. This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1652,6 +1784,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*): adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded. `default_{i}` where i is the total number of adapters being loaded.
config (`dict`): Configuration that was used to train this LoRA.
""" """
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
@@ -1678,7 +1811,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if "lora_B" in key: if "lora_B" in key:
rank[key] = val.shape[1] rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) if config is not None and isinstance(config, dict) and len(config) > 0:
config = json.loads(config[cls.transformer_name])
lora_config_kwargs = get_peft_kwargs(
rank, network_alpha_dict=None, config=config, peft_state_dict=state_dict
)
if "use_dora" in lora_config_kwargs: if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError( raise ValueError(
@@ -1725,6 +1862,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
prefix=None, prefix=None,
lora_scale=1.0, lora_scale=1.0,
adapter_name=None, adapter_name=None,
config=None,
_pipeline=None, _pipeline=None,
): ):
""" """
@@ -1746,6 +1884,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*): adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded. `default_{i}` where i is the total number of adapters being loaded.
config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
""" """
if not USE_PEFT_BACKEND: if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.") raise ValueError("PEFT backend is required for this method.")
@@ -1758,6 +1897,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
keys = list(state_dict.keys()) keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix prefix = cls.text_encoder_name if prefix is None else prefix
if config is not None and len(config) > 0:
config = json.loads(config[prefix])
# Safe prefix to check with. # Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys): if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments. # Load the layers corresponding to text encoder and make necessary adjustments.
@@ -1796,7 +1938,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
} }
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
)
if "use_dora" in lora_config_kwargs: if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]: if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"): if is_peft_version("<", "0.9.0"):
@@ -1841,6 +1985,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
transformer_lora_config: dict = None,
text_encoder_lora_config: dict = None,
is_main_process: bool = True, is_main_process: bool = True,
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
@@ -1857,6 +2003,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
transformer_lora_config (`dict`, *optional*):
LoRA configuration used to train the `transformer_lora_layers`.
text_encoder_lora_config (`dict`, *optional*):
LoRA configuration used to train the `text_encoder_lora_layers`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -1869,19 +2019,27 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
metadata = {}
if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_config:
transformer_metadata = cls.pack_metadata(transformer_lora_config, cls.transformer_name)
metadata.update(transformer_metadata)
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
if text_encoder_lora_config:
te_metadata = cls.pack_metadata(text_encoder_lora_config, cls.text_encoder_name)
metadata.update(te_metadata)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
metadata=metadata,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
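For reference, a usage sketch of the new arguments, modeled on the test added at the end of this diff; `transformer` and `text_encoder` here are assumed to be PEFT-adapted models trained elsewhere, and the save directory name is made up for illustration.

from diffusers import FluxPipeline
from peft import get_peft_model_state_dict

# Passing the LoRA configs alongside the state dicts embeds them as metadata
# in the saved safetensors file (a sketch; `transformer` / `text_encoder` are assumed inputs).
FluxPipeline.save_lora_weights(
    save_directory="flux-lora-checkpoint",
    transformer_lora_layers=get_peft_model_state_dict(transformer),
    transformer_lora_config=transformer.peft_config["default"],
    text_encoder_lora_layers=get_peft_model_state_dict(text_encoder),
    text_encoder_lora_config=text_encoder.peft_config["default"],
)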
@@ -1933,6 +2091,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
@@ -2051,6 +2210,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
prefix=None,
lora_scale=1.0,
adapter_name=None,
config=None,
_pipeline=None,
):
"""
@@ -2072,6 +2232,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -2084,6 +2245,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
if config is not None and len(config) > 0:
config = json.loads(config[prefix])
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
@@ -2122,7 +2286,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
} }
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
)
if "use_dora" in lora_config_kwargs: if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]: if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"): if is_peft_version("<", "0.9.0"):

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import defaultdict
from contextlib import nullcontext
@@ -115,6 +116,7 @@ class UNet2DConditionLoadersMixin:
`default_{i}` where i is the total number of adapters being loaded.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
Example:
@@ -143,6 +145,7 @@ class UNet2DConditionLoadersMixin:
_pipeline = kwargs.pop("_pipeline", None)
network_alphas = kwargs.pop("network_alphas", None)
allow_pickle = False
config = kwargs.pop("config", None)
if use_safetensors is None:
use_safetensors = True
@@ -208,6 +211,7 @@ class UNet2DConditionLoadersMixin:
unet_identifier_key=self.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
config=config,
_pipeline=_pipeline,
)
else:
@@ -268,7 +272,7 @@ class UNet2DConditionLoadersMixin:
return attn_processors
def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, config=None):
# This method does the following things:
# 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
# format. For legacy format no filtering is applied.
@@ -316,7 +320,10 @@ class UNet2DConditionLoadersMixin:
if "lora_B" in key: if "lora_B" in key:
rank[key] = val.shape[1] rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) if config is not None and isinstance(config, dict) and len(config) > 0:
config = json.loads(config["unet"])
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True)
if "use_dora" in lora_config_kwargs: if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]: if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"): if is_peft_version("<", "0.9.0"):
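The reason each config travels as a JSON string is that safetensors metadata is a flat string-to-string mapping. A minimal standalone round-trip sketch, independent of the loader code above; the filename and tensor key are made up for illustration.

import json
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Serialize a toy state dict with the config stored under the component prefix.
state_dict = {"unet.dummy.lora_A.weight": torch.zeros(4, 8)}
metadata = {"unet": json.dumps({"r": 4, "lora_alpha": 8})}
save_file(state_dict, "lora_with_metadata.safetensors", metadata=metadata)

# Read the metadata back from the file header and recover the config dict.
with safe_open("lora_with_metadata.safetensors", framework="pt", device="cpu") as f:
    restored = json.loads(f.metadata()["unet"])
assert restored["lora_alpha"] == 8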

View File

@@ -21,9 +21,12 @@ from typing import Optional
from packaging import version
from . import logging
from .import_utils import is_peft_available, is_torch_available
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
@@ -147,11 +150,29 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
# Try to retrieve config.
alpha_retrieved = False
if config is not None:
lora_alpha = config["lora_alpha"] if "lora_alpha" in config else lora_alpha
alpha_retrieved = True
# We simply ignore the `alpha_pattern` and `rank_pattern` if they are found
# in the `config`. This is because:
# 1. We determine `rank_pattern` from the `rank_dict`.
# 2. When `network_alpha_dict` is passed that means the underlying checkpoint
# is a non-diffusers checkpoint.
# More details: https://github.com/huggingface/diffusers/pull/9143#discussion_r1711491175
if config.get("alpha_pattern", None) is not None:
logger.warning("`alpha_pattern` found in the LoRA config. This will be ignored.")
if config.get("rank_pattern", None) is not None:
logger.warning("`rank_pattern` found in the LoRA config. This will be ignored.")
if len(set(rank_dict.values())) > 1:
# get the rank occuring the most number of times
r = collections.Counter(rank_dict.values()).most_common()[0][0]
@@ -160,7 +181,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
if not alpha_retrieved and network_alpha_dict is not None and len(network_alpha_dict) > 0:
if len(set(network_alpha_dict.values())) > 1:
# get the alpha occuring the most number of times
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
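A standalone sketch of the precedence implemented above (not the library function itself): `lora_alpha` coming from the serialized config wins over values inferred from `network_alphas`, which in turn win over the rank-derived default.

import collections

def resolve_lora_alpha(rank_dict, network_alpha_dict, config=None):
    # Default: fall back to the first rank value, as get_peft_kwargs does.
    lora_alpha = list(rank_dict.values())[0]
    alpha_retrieved = False
    if config is not None and "lora_alpha" in config:
        lora_alpha = config["lora_alpha"]
        alpha_retrieved = True
    if not alpha_retrieved and network_alpha_dict:
        # Most common alpha across the checkpoint, mirroring the logic above.
        lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
    return lora_alpha

print(resolve_lora_alpha({"x.lora_B.weight": 4}, {"x.alpha": 16}, config={"lora_alpha": 8}))  # 8
print(resolve_lora_alpha({"x.lora_B.weight": 4}, {"x.alpha": 16}))                            # 16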

View File

@@ -19,7 +19,7 @@ import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, torch_device
sys.path.append(".")
@@ -28,6 +28,7 @@ from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()

View File

@@ -87,7 +87,7 @@ class PeftLoraLoaderMixinTests:
transformer_kwargs = None
vae_kwargs = None
def get_dummy_components(self, scheduler_cls=None, use_dora=False, alpha=None):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders:
@@ -95,6 +95,7 @@ class PeftLoraLoaderMixinTests:
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
rank = 4
alpha = rank if alpha is None else alpha
torch.manual_seed(0)
if self.unet_kwargs is not None:
@@ -120,7 +121,7 @@ class PeftLoraLoaderMixinTests:
text_lora_config = LoraConfig(
r=rank,
lora_alpha=alpha,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
init_lora_weights=False,
use_dora=use_dora,
@@ -128,7 +129,7 @@ class PeftLoraLoaderMixinTests:
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
@@ -1752,3 +1753,85 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs()
_ = pipe(**inputs).images
def test_if_lora_alpha_is_correctly_parsed(self):
lora_alpha = 8
scheduler_class = FlowMatchEulerDiscreteScheduler if self.uses_flow_matching else DDIMScheduler
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
alpha=lora_alpha, scheduler_cls=scheduler_class
)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
# Inference works?
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname:
denoiser = pipe.unet if self.unet_kwargs else pipe.transformer
denoiser_state_dict = get_peft_model_state_dict(denoiser)
denoiser_lora_config = denoiser.peft_config["default"]
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
text_encoder_lora_config = pipe.text_encoder.peft_config["default"]
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"]
saving_kwargs = {
"save_directory": tmpdirname,
"text_encoder_lora_layers": text_encoder_state_dict,
"text_encoder_lora_config": text_encoder_lora_config,
}
if self.unet_kwargs is not None:
saving_kwargs.update(
{"unet_lora_layers": denoiser_state_dict, "unet_lora_config": denoiser_lora_config}
)
else:
saving_kwargs.update(
{"transformer_lora_layers": denoiser_state_dict, "transformer_lora_config": denoiser_lora_config}
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
saving_kwargs.update(
{
"text_encoder_2_lora_layers": text_encoder_2_state_dict,
"text_encoder_2_lora_config": text_encoder_2_lora_config,
}
)
self.pipeline_class.save_lora_weights(**saving_kwargs)
loaded_pipe = self.pipeline_class(**components)
loaded_pipe.load_lora_weights(tmpdirname)
# Inference works?
_ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images
denoiser_loaded = pipe.unet if self.unet_kwargs is not None else pipe.transformer
assert (
denoiser_loaded.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for UNet."
assert (
loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder."
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
assert (
loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder 2."