Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-14 16:34:27 +08:00)

Compare commits: onnx-cpu-d...feat-confi (19 commits)

| SHA1 |
|---|
| 9d411fd5c2 |
| 178a4596b2 |
| e7808e4c35 |
| d79c0d5970 |
| f7d30de31a |
| 712a110e76 |
| 1852c3fa3b |
| 6fb5987e40 |
| 91cfffc54b |
| 632bf78c04 |
| 48768b66b6 |
| 225634e8d1 |
| 6bbf629349 |
| 167557b9bf |
| 733e1d9259 |
| 79aff1d82b |
| fb9e86d806 |
| 2b9f77e2d3 |
| 1819000151 |
`src/diffusers/loaders/lora_base.py`:

```diff
@@ -14,6 +14,7 @@
 
 import copy
 import inspect
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -26,6 +27,7 @@ from huggingface_hub.constants import HF_HUB_OFFLINE
 
 from ..models.modeling_utils import ModelMixin, load_state_dict
 from ..utils import (
+    SAFETENSORS_FILE_EXTENSION,
     USE_PEFT_BACKEND,
     _get_model_file,
     delete_adapter_layers,
@@ -44,6 +46,7 @@ if is_transformers_available():
     from transformers import PreTrainedModel
 
 if is_peft_available():
+    from peft import LoraConfig
     from peft.tuners.tuners_utils import BaseTunerLayer
 
 if is_accelerate_available():
@@ -252,6 +255,7 @@ class LoraBaseMixin:
         from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
 
         model_file = None
+        metadata = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
             if (use_safetensors and weight_name is None) or (
@@ -280,6 +284,8 @@ class LoraBaseMixin:
                         user_agent=user_agent,
                     )
                     state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                    with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
+                        metadata = f.metadata()
                 except (IOError, safetensors.SafetensorError) as e:
                     if not allow_pickle:
                         raise e
@@ -305,10 +311,14 @@ class LoraBaseMixin:
                     user_agent=user_agent,
                 )
                 state_dict = load_state_dict(model_file)
+                file_extension = os.path.basename(model_file).split(".")[-1]
+                if file_extension == SAFETENSORS_FILE_EXTENSION:
+                    with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
+                        metadata = f.metadata()
         else:
             state_dict = pretrained_model_name_or_path_or_dict
 
-        return state_dict
+        return state_dict, metadata
 
     @classmethod
     def _best_guess_weight_name(
@@ -709,6 +719,20 @@ class LoraBaseMixin:
         layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
         return layers_state_dict
 
+    @staticmethod
+    def pack_metadata(config, prefix):
+        local_metadata = {}
+        if config is not None:
+            if isinstance(config, LoraConfig):
+                config = config.to_dict()
+            for key, value in config.items():
+                if isinstance(value, set):
+                    config[key] = list(value)
+
+            config_as_string = json.dumps(config, indent=2, sort_keys=True)
+            local_metadata[prefix] = config_as_string
+        return local_metadata
+
     @staticmethod
     def write_lora_layers(
         state_dict: Dict[str, torch.Tensor],
@@ -717,9 +741,13 @@ class LoraBaseMixin:
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        metadata=None,
     ):
         from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
 
+        if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0:
+            raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.")
+
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
@@ -727,8 +755,12 @@ class LoraBaseMixin:
         if save_function is None:
             if safe_serialization:
 
-                def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                def save_function(weights, filename, metadata):
+                    if metadata is None:
+                        metadata = {"format": "pt"}
+                    elif len(metadata) > 0:
+                        metadata.update({"format": "pt"})
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)
 
             else:
                 save_function = torch.save
@@ -742,7 +774,10 @@ class LoraBaseMixin:
             weight_name = LORA_WEIGHT_NAME
 
         save_path = Path(save_directory, weight_name).as_posix()
-        save_function(state_dict, save_path)
+        if save_function != torch.save:
+            save_function(state_dict, save_path, metadata)
+        else:
+            save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")
 
     @property
```
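The `LoraBaseMixin` hunks above add the plumbing for LoRA-config metadata: `pack_metadata` JSON-serializes a `LoraConfig` (or plain dict) under a component prefix, `write_lora_layers` stores the resulting strings in the safetensors header, and `_fetch_state_dict` reads them back via `safetensors.safe_open`. The standalone sketch below shows the same round-trip with plain `safetensors` and `peft`; the file name, tensor name, and example `LoraConfig` values are illustrative assumptions, not part of the diff.

```python
# Minimal sketch of the metadata round-trip implemented above.
import json

import safetensors
import safetensors.torch
import torch
from peft import LoraConfig

# 1. Pack a LoRA config into string-valued metadata keyed by component prefix
#    (mirrors LoraBaseMixin.pack_metadata).
config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_k", "to_q", "to_v", "to_out.0"])
config_dict = config.to_dict()
for key, value in config_dict.items():
    if isinstance(value, set):  # sets (e.g. target_modules) are not JSON-serializable
        config_dict[key] = list(value)
metadata = {"unet": json.dumps(config_dict, indent=2, sort_keys=True)}

# 2. Save tensors together with the metadata (mirrors the patched save_function in
#    write_lora_layers). safetensors only accepts str -> str metadata.
metadata.update({"format": "pt"})
state_dict = {"unet.dummy.lora_A.weight": torch.zeros(4, 8)}
safetensors.torch.save_file(state_dict, "lora_with_metadata.safetensors", metadata=metadata)

# 3. Read the metadata back without loading the tensors (mirrors _fetch_state_dict).
with safetensors.safe_open("lora_with_metadata.safetensors", framework="pt", device="cpu") as f:
    loaded_metadata = f.metadata()
print(json.loads(loaded_metadata["unet"])["r"])  # -> 4
```

Storing the config as a JSON string in the safetensors header keeps the weight file self-describing without changing the tensor layout.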
`src/diffusers/loaders/lora_pipeline.py`:

```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import os
 from typing import Callable, Dict, List, Optional, Union
 
@@ -92,7 +93,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, network_alphas, metadata = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True
+        )
 
         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -104,6 +107,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
             adapter_name=adapter_name,
             _pipeline=self,
+            config=metadata,
         )
         self.load_lora_into_text_encoder(
             state_dict,
@@ -113,6 +117,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             else self.text_encoder,
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
+            config=metadata,
             _pipeline=self,
         )
 
@@ -168,6 +173,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            return_metadata (`bool`):
+                If state dict metadata should be returned. Is only supported when the state dict has a safetensors
+                extension.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -181,6 +189,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         weight_name = kwargs.pop("weight_name", None)
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_metadata = kwargs.pop("return_metadata", False)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -192,7 +201,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }
 
-        state_dict = cls._fetch_state_dict(
+        state_dict, metadata = cls._fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -224,10 +233,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
 
-        return state_dict, network_alphas
+        if return_metadata:
+            return state_dict, network_alphas, metadata
+        else:
+            return state_dict, network_alphas
 
     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, config=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
 
@@ -245,6 +257,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -258,7 +271,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
             unet.load_attn_procs(
-                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                config=config,
+                _pipeline=_pipeline,
             )
 
     @classmethod
@@ -270,6 +287,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        config=None,
         _pipeline=None,
     ):
         """
@@ -291,6 +309,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -303,6 +322,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         keys = list(state_dict.keys())
         prefix = cls.text_encoder_name if prefix is None else prefix
 
+        if config is not None and len(config) > 0:
+            config = json.loads(config[prefix])
+
         # Safe prefix to check with.
         if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
@@ -341,7 +363,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                     k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                 }
 
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+            lora_config_kwargs = get_peft_kwargs(
+                rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
+            )
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
                     if is_peft_version("<", "0.9.0"):
@@ -385,6 +409,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         save_directory: Union[str, os.PathLike],
         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+        unet_lora_config: dict = None,
+        text_encoder_lora_config: dict = None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
@@ -401,6 +427,10 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                 encoder LoRA state dict because it comes from 🤗 Transformers.
+            unet_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `unet_lora_layers`.
+            text_encoder_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `text_encoder_lora_layers`.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -413,19 +443,27 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         state_dict = {}
+        metadata = {}
 
         if not (unet_lora_layers or text_encoder_lora_layers):
             raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
 
         if unet_lora_layers:
             state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+            if unet_lora_config:
+                unet_metadata = cls.pack_metadata(unet_lora_config, cls.unet_name)
+                metadata.update(unet_metadata)
 
         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+            if text_encoder_lora_config:
+                te_metadata = cls.pack_metadata(text_encoder_lora_config, cls.text_encoder_name)
+                metadata.update(te_metadata)
 
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
+            metadata=metadata,
             save_directory=save_directory,
             is_main_process=is_main_process,
             weight_name=weight_name,
@@ -489,10 +527,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
 
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
-            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
         super().unfuse_lora(components=components)
 
```
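At the pipeline level, the `StableDiffusionLoraLoaderMixin` hunks above expose this through new `unet_lora_config` / `text_encoder_lora_config` arguments on `save_lora_weights` and a `return_metadata` flag on `lora_state_dict`. The sketch below is a hedged usage example based only on those signatures: the checkpoint name, the `LoraConfig` values, and the helper imports (`get_peft_model_state_dict`, `convert_state_dict_to_diffusers`, following the diffusers training scripts) are assumptions and not part of this diff.

```python
# Hypothetical usage of the new per-component config arguments and return_metadata flag.
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

from diffusers import StableDiffusionPipeline
from diffusers.utils import convert_state_dict_to_diffusers

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Attach a fresh (untrained) LoRA adapter to the UNet just to have weights to save.
unet_lora_config = LoraConfig(
    r=4, lora_alpha=4, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
)
pipe.unet.add_adapter(unet_lora_config)
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(pipe.unet))

# New in this diff: the training-time LoraConfig is saved alongside the weights
# as safetensors metadata via pack_metadata / write_lora_layers.
pipe.save_lora_weights(
    save_directory="sd-lora-with-metadata",
    unet_lora_layers=unet_lora_layers,
    unet_lora_config=unet_lora_config,
)

# load_lora_weights now fetches that metadata and forwards it as `config=` to
# load_lora_into_unet / load_lora_into_text_encoder; it can also be inspected directly.
state_dict, network_alphas, metadata = pipe.lora_state_dict(
    "sd-lora-with-metadata", return_metadata=True
)
print(list(metadata.keys()))  # e.g. ["unet", "format"]
```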
`src/diffusers/loaders/lora_pipeline.py` (continued):

```diff
@@ -550,9 +584,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(
+        state_dict, network_alphas, metadata = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
+            return_metadata=True,
             **kwargs,
         )
         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
@@ -560,7 +595,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             raise ValueError("Invalid LoRA checkpoint.")
 
         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            adapter_name=adapter_name,
+            config=metadata,
+            _pipeline=self,
         )
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
@@ -571,6 +611,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
+                config=metadata,
                 _pipeline=self,
             )
 
@@ -583,6 +624,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
+                config=metadata,
                 _pipeline=self,
             )
 
@@ -639,6 +681,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            return_metadata (`bool`):
+                If state dict metadata should be returned. Is only supported when the state dict has a safetensors
+                extension.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -652,6 +697,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         weight_name = kwargs.pop("weight_name", None)
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_metadata = kwargs.pop("return_metadata", False)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -663,7 +709,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }
 
-        state_dict = cls._fetch_state_dict(
+        state_dict, metadata = cls._fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -695,11 +741,14 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
 
-        return state_dict, network_alphas
+        if return_metadata:
+            return state_dict, network_alphas, metadata
+        else:
+            return state_dict, network_alphas
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, config=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
 
@@ -717,6 +766,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -730,7 +780,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
             unet.load_attn_procs(
-                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                config=config,
+                _pipeline=_pipeline,
             )
 
     @classmethod
@@ -743,6 +797,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        config=None,
         _pipeline=None,
     ):
         """
@@ -764,6 +819,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -776,6 +832,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         keys = list(state_dict.keys())
         prefix = cls.text_encoder_name if prefix is None else prefix
 
+        if config is not None and len(config) > 0:
+            config = json.loads(config[prefix])
+
         # Safe prefix to check with.
         if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
@@ -814,7 +873,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                     k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                 }
 
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+            lora_config_kwargs = get_peft_kwargs(
+                rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
+            )
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
                     if is_peft_version("<", "0.9.0"):
@@ -859,6 +920,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        unet_lora_config: dict = None,
+        text_encoder_lora_config: dict = None,
+        text_encoder_2_lora_config: dict = None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
@@ -878,6 +942,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
                 encoder LoRA state dict because it comes from 🤗 Transformers.
+            unet_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `unet_lora_layers`.
+            text_encoder_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `text_encoder_lora_layers`.
+            text_encoder_2_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `text_encoder_2_lora_layers`.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -890,6 +960,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         state_dict = {}
+        metadata = {}
 
         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
@@ -898,15 +969,25 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
 
         if unet_lora_layers:
             state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
+            if unet_lora_config is not None:
+                unet_metadata = cls.pack_metadata(unet_lora_config, "unet")
+                metadata.update(unet_metadata)
 
         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+            if text_encoder_lora_config is not None:
+                te_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder")
+                metadata.update(te_metadata)
 
         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+            if text_encoder_2_lora_config is not None:
+                te2_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder_2")
+                metadata.update(te2_metadata)
 
         cls.write_lora_layers(
             state_dict=state_dict,
+            metadata=metadata,
             save_directory=save_directory,
             is_main_process=is_main_process,
             weight_name=weight_name,
@@ -970,10 +1051,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
 
         Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
-            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
         super().unfuse_lora(components=components)
 
```
@@ -1041,6 +1118,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
allowed by Git.
|
allowed by Git.
|
||||||
subfolder (`str`, *optional*, defaults to `""`):
|
subfolder (`str`, *optional*, defaults to `""`):
|
||||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||||
|
return_metadata (`bool`, *optional*):
|
||||||
|
If state dict metadata should be returned. Is only supported when the state dict has a safetensors
|
||||||
|
extension.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Load the main state dict first which has the LoRA layers for either of
|
# Load the main state dict first which has the LoRA layers for either of
|
||||||
@@ -1054,6 +1134,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
subfolder = kwargs.pop("subfolder", None)
|
subfolder = kwargs.pop("subfolder", None)
|
||||||
weight_name = kwargs.pop("weight_name", None)
|
weight_name = kwargs.pop("weight_name", None)
|
||||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||||
|
return_metadata = kwargs.pop("return_metadata", False)
|
||||||
|
|
||||||
allow_pickle = False
|
allow_pickle = False
|
||||||
if use_safetensors is None:
|
if use_safetensors is None:
|
||||||
@@ -1065,7 +1146,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
"framework": "pytorch",
|
"framework": "pytorch",
|
||||||
}
|
}
|
||||||
|
|
||||||
state_dict = cls._fetch_state_dict(
|
state_dict, metadata = cls._fetch_state_dict(
|
||||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||||
weight_name=weight_name,
|
weight_name=weight_name,
|
||||||
use_safetensors=use_safetensors,
|
use_safetensors=use_safetensors,
|
||||||
@@ -1080,13 +1161,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
allow_pickle=allow_pickle,
|
allow_pickle=allow_pickle,
|
||||||
)
|
)
|
||||||
|
|
||||||
return state_dict
|
# Otherwise, this would be a breaking change.
|
||||||
|
if return_metadata:
|
||||||
|
return state_dict, metadata
|
||||||
|
else:
|
||||||
|
return state_dict
|
||||||
|
|
||||||
def load_lora_weights(
|
def load_lora_weights(
|
||||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
|
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||||
`self.text_encoder`.
|
`self.text_encoder`.
|
||||||
|
|
||||||
All kwargs are forwarded to `self.lora_state_dict`.
|
All kwargs are forwarded to `self.lora_state_dict`.
|
||||||
@@ -1114,7 +1199,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||||
|
|
||||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
state_dict, metadata = self.lora_state_dict(
|
||||||
|
pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True
|
||||||
|
)
|
||||||
|
|
||||||
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
||||||
if not is_correct_format:
|
if not is_correct_format:
|
||||||
@@ -1124,6 +1211,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
state_dict,
|
state_dict,
|
||||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||||
adapter_name=adapter_name,
|
adapter_name=adapter_name,
|
||||||
|
config=metadata,
|
||||||
_pipeline=self,
|
_pipeline=self,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1136,6 +1224,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
prefix="text_encoder",
|
prefix="text_encoder",
|
||||||
lora_scale=self.lora_scale,
|
lora_scale=self.lora_scale,
|
||||||
adapter_name=adapter_name,
|
adapter_name=adapter_name,
|
||||||
|
config=metadata,
|
||||||
_pipeline=self,
|
_pipeline=self,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1148,11 +1237,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
prefix="text_encoder_2",
|
prefix="text_encoder_2",
|
||||||
lora_scale=self.lora_scale,
|
lora_scale=self.lora_scale,
|
||||||
adapter_name=adapter_name,
|
adapter_name=adapter_name,
|
||||||
|
config=metadata,
|
||||||
_pipeline=self,
|
_pipeline=self,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
|
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, config=None, _pipeline=None):
|
||||||
"""
|
"""
|
||||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||||
|
|
||||||
@@ -1166,6 +1256,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
adapter_name (`str`, *optional*):
|
adapter_name (`str`, *optional*):
|
||||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||||
`default_{i}` where i is the total number of adapters being loaded.
|
`default_{i}` where i is the total number of adapters being loaded.
|
||||||
|
config (`dict`): Configuration that was used to train this LoRA.
|
||||||
"""
|
"""
|
||||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||||
|
|
||||||
@@ -1192,7 +1283,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
if "lora_B" in key:
|
if "lora_B" in key:
|
||||||
rank[key] = val.shape[1]
|
rank[key] = val.shape[1]
|
||||||
|
|
||||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
|
if config is not None and isinstance(config, dict) and len(config) > 0:
|
||||||
|
config = json.loads(config[cls.transformer_name])
|
||||||
|
lora_config_kwargs = get_peft_kwargs(
|
||||||
|
rank, network_alpha_dict=None, config=config, peft_state_dict=state_dict
|
||||||
|
)
|
||||||
if "use_dora" in lora_config_kwargs:
|
if "use_dora" in lora_config_kwargs:
|
||||||
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -1239,6 +1334,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
prefix=None,
|
prefix=None,
|
||||||
lora_scale=1.0,
|
lora_scale=1.0,
|
||||||
adapter_name=None,
|
adapter_name=None,
|
||||||
|
config=None,
|
||||||
_pipeline=None,
|
_pipeline=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1260,6 +1356,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
adapter_name (`str`, *optional*):
|
adapter_name (`str`, *optional*):
|
||||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||||
`default_{i}` where i is the total number of adapters being loaded.
|
`default_{i}` where i is the total number of adapters being loaded.
|
||||||
|
config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
|
||||||
"""
|
"""
|
||||||
if not USE_PEFT_BACKEND:
|
if not USE_PEFT_BACKEND:
|
||||||
raise ValueError("PEFT backend is required for this method.")
|
raise ValueError("PEFT backend is required for this method.")
|
||||||
@@ -1272,6 +1369,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
keys = list(state_dict.keys())
|
keys = list(state_dict.keys())
|
||||||
prefix = cls.text_encoder_name if prefix is None else prefix
|
prefix = cls.text_encoder_name if prefix is None else prefix
|
||||||
|
|
||||||
|
if config is not None and len(config) > 0:
|
||||||
|
config = json.loads(config[prefix])
|
||||||
|
|
||||||
# Safe prefix to check with.
|
# Safe prefix to check with.
|
||||||
if any(cls.text_encoder_name in key for key in keys):
|
if any(cls.text_encoder_name in key for key in keys):
|
||||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||||
@@ -1310,7 +1410,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
lora_config_kwargs = get_peft_kwargs(
|
||||||
|
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
|
||||||
|
)
|
||||||
if "use_dora" in lora_config_kwargs:
|
if "use_dora" in lora_config_kwargs:
|
||||||
if lora_config_kwargs["use_dora"]:
|
if lora_config_kwargs["use_dora"]:
|
||||||
if is_peft_version("<", "0.9.0"):
|
if is_peft_version("<", "0.9.0"):
|
||||||
@@ -1349,12 +1451,16 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
# Unsafe code />
|
# Unsafe code />
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
|
||||||
def save_lora_weights(
|
def save_lora_weights(
|
||||||
cls,
|
cls,
|
||||||
save_directory: Union[str, os.PathLike],
|
save_directory: Union[str, os.PathLike],
|
||||||
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||||
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||||
|
transformer_lora_config: dict = None,
|
||||||
|
text_encoder_lora_config: dict = None,
|
||||||
|
text_encoder_2_lora_config: dict = None,
|
||||||
is_main_process: bool = True,
|
is_main_process: bool = True,
|
||||||
weight_name: str = None,
|
weight_name: str = None,
|
||||||
save_function: Callable = None,
|
save_function: Callable = None,
|
||||||
@@ -1374,6 +1480,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||||
State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
|
State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
|
||||||
encoder LoRA state dict because it comes from 🤗 Transformers.
|
encoder LoRA state dict because it comes from 🤗 Transformers.
|
||||||
|
transformer_lora_config (`dict`, *optional*):
|
||||||
|
LoRA configuration used to train the `transformer_lora_layers`.
|
||||||
|
text_encoder_lora_config (`dict`, *optional*):
|
||||||
|
LoRA configuration used to train the `text_encoder_lora_layers`.
|
||||||
|
text_encoder_2_lora_config (`dict`, *optional*):
|
||||||
|
LoRA configuration used to train the `text_encoder_2_lora_layers`.
|
||||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||||
@@ -1386,24 +1498,34 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||||
"""
|
"""
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
|
"You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
|
||||||
)
|
)
|
||||||
|
|
||||||
if transformer_lora_layers:
|
if transformer_lora_layers:
|
||||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
state_dict.update(cls.pack_weights(transformer_lora_layers, "transformer"))
|
||||||
|
if transformer_lora_config is not None:
|
||||||
|
transformer_metadata = cls.pack_metadata(transformer_lora_config, "transformer")
|
||||||
|
metadata.update(transformer_metadata)
|
||||||
|
|
||||||
if text_encoder_lora_layers:
|
if text_encoder_lora_layers:
|
||||||
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
|
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||||
|
if text_encoder_lora_config is not None:
|
||||||
|
te_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder")
|
||||||
|
metadata.update(te_metadata)
|
||||||
|
|
||||||
if text_encoder_2_lora_layers:
|
if text_encoder_2_lora_layers:
|
||||||
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||||
|
if text_encoder_2_lora_config is not None:
|
||||||
|
te2_metadata = cls.pack_metadata(text_encoder_lora_config, "text_encoder_2")
|
||||||
|
metadata.update(te2_metadata)
|
||||||
|
|
||||||
# Save the model
|
|
||||||
cls.write_lora_layers(
|
cls.write_lora_layers(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
|
metadata=metadata,
|
||||||
save_directory=save_directory,
|
save_directory=save_directory,
|
||||||
is_main_process=is_main_process,
|
is_main_process=is_main_process,
|
||||||
weight_name=weight_name,
|
weight_name=weight_name,
|
||||||
@@ -1411,6 +1533,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
safe_serialization=safe_serialization,
|
safe_serialization=safe_serialization,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
|
||||||
def fuse_lora(
|
def fuse_lora(
|
||||||
self,
|
self,
|
||||||
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
|
         components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
@@ -1454,6 +1577,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
     def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
         r"""
         Reverses the effect of
@@ -1467,10 +1591,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
 
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
-            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
         super().unfuse_lora(components=components)
 
@@ -1538,6 +1658,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_metadata (`bool`, *optional*):
+                If state dict metadata should be returned. Is only supported when the state dict has a safetensors
+                extension.
 
         """
         # Load the main state dict first which has the LoRA layers for either of
@@ -1551,6 +1674,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_metadata = kwargs.pop("return_metadata", False)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -1562,7 +1686,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }
 
-        state_dict = cls._fetch_state_dict(
+        state_dict, metadata = cls._fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -1577,7 +1701,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             allow_pickle=allow_pickle,
         )
 
-        return state_dict
+        # Otherwise, this would be a breaking change.
+        if return_metadata:
+            return state_dict, metadata
+        else:
+            return state_dict
 
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
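
A rough usage sketch (not part of the patch) of the `return_metadata` flag introduced above; the repository id is a placeholder and the weight name only assumes the usual diffusers naming convention:

from diffusers import FluxPipeline

# With return_metadata=True and a .safetensors checkpoint, lora_state_dict also
# hands back the metadata embedded in the file (the serialized LoRA configs).
state_dict, metadata = FluxPipeline.lora_state_dict(
    "some-user/some-flux-lora",  # placeholder repo id
    weight_name="pytorch_lora_weights.safetensors",
    return_metadata=True,
)
print(metadata)
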
@@ -1611,7 +1739,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, metadata = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, **kwargs, return_metadata=True
+        )
 
         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -1621,6 +1751,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            config=metadata,
             _pipeline=self,
         )
 
@@ -1633,12 +1764,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
+                config=metadata,
                 _pipeline=self,
             )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
-    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, config=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -1652,6 +1784,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`): Configuration that was used to train this LoRA.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
 
@@ -1678,7 +1811,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 if "lora_B" in key:
                     rank[key] = val.shape[1]
 
-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+            if config is not None and isinstance(config, dict) and len(config) > 0:
+                config = json.loads(config[cls.transformer_name])
+            lora_config_kwargs = get_peft_kwargs(
+                rank, network_alpha_dict=None, config=config, peft_state_dict=state_dict
+            )
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
                     raise ValueError(
@@ -1725,6 +1862,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        config=None,
         _pipeline=None,
     ):
         """
@@ -1746,6 +1884,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -1758,6 +1897,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         keys = list(state_dict.keys())
         prefix = cls.text_encoder_name if prefix is None else prefix
 
+        if config is not None and len(config) > 0:
+            config = json.loads(config[prefix])
+
         # Safe prefix to check with.
         if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
@@ -1796,7 +1938,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                         k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                     }
 
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+                lora_config_kwargs = get_peft_kwargs(
+                    rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
+                )
                 if "use_dora" in lora_config_kwargs:
                     if lora_config_kwargs["use_dora"]:
                         if is_peft_version("<", "0.9.0"):
@@ -1841,6 +1985,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         save_directory: Union[str, os.PathLike],
         transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+        transformer_lora_config: dict = None,
+        text_encoder_lora_config: dict = None,
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
@@ -1857,6 +2003,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                 encoder LoRA state dict because it comes from 🤗 Transformers.
+            transformer_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `transformer_lora_layers`.
+            text_encoder_lora_config (`dict`, *optional*):
+                LoRA configuration used to train the `text_encoder_lora_layers`.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -1869,19 +2019,27 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         state_dict = {}
+        metadata = {}
 
         if not (transformer_lora_layers or text_encoder_lora_layers):
             raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
 
         if transformer_lora_layers:
             state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            if transformer_lora_config:
+                transformer_metadata = cls.pack_metadata(transformer_lora_config, cls.transformer_name)
+                metadata.update(transformer_metadata)
 
         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+            if text_encoder_lora_config:
+                te_metadata = cls.pack_metadata(text_encoder_lora_config, cls.text_encoder_name)
+                metadata.update(te_metadata)
 
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
+            metadata=metadata,
             save_directory=save_directory,
             is_main_process=is_main_process,
             weight_name=weight_name,
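
A hedged end-to-end sketch of the new save path. It assumes a Flux pipeline `pipe` that already carries a trained LoRA adapter named "default" and uses a placeholder output directory; passing the training `LoraConfig` is what gets it embedded as safetensors metadata next to the weights:

from peft.utils import get_peft_model_state_dict

from diffusers import FluxPipeline

transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
transformer_lora_config = pipe.transformer.peft_config["default"]  # LoraConfig used for training

FluxPipeline.save_lora_weights(
    save_directory="flux-lora-out",  # placeholder path
    transformer_lora_layers=transformer_lora_layers,
    transformer_lora_config=transformer_lora_config,
)
# A later load_lora_weights("flux-lora-out") can then rebuild the exact LoraConfig
# (for example a lora_alpha that differs from the rank) from that metadata.
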
@@ -1933,6 +2091,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
 
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
     def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
         r"""
         Reverses the effect of
@@ -2051,6 +2210,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        config=None,
         _pipeline=None,
     ):
         """
@@ -2072,6 +2232,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            config (`dict`, *optional*): LoRA configuration (`LoraConfig` dict) used during training.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -2084,6 +2245,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
         keys = list(state_dict.keys())
         prefix = cls.text_encoder_name if prefix is None else prefix
 
+        if config is not None and len(config) > 0:
+            config = json.loads(config[prefix])
+
         # Safe prefix to check with.
         if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
@@ -2122,7 +2286,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                         k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                     }
 
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+                lora_config_kwargs = get_peft_kwargs(
+                    rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
+                )
                 if "use_dora" in lora_config_kwargs:
                     if lora_config_kwargs["use_dora"]:
                         if is_peft_version("<", "0.9.0"):

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import os
 from collections import defaultdict
 from contextlib import nullcontext
@@ -115,6 +116,7 @@ class UNet2DConditionLoadersMixin:
                 `default_{i}` where i is the total number of adapters being loaded.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            config: (`dict`, *optional*)
 
         Example:
 
@@ -143,6 +145,7 @@ class UNet2DConditionLoadersMixin:
         _pipeline = kwargs.pop("_pipeline", None)
         network_alphas = kwargs.pop("network_alphas", None)
         allow_pickle = False
+        config = kwargs.pop("config", None)
 
         if use_safetensors is None:
             use_safetensors = True
@@ -208,6 +211,7 @@ class UNet2DConditionLoadersMixin:
                 unet_identifier_key=self.unet_name,
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
+                config=config,
                 _pipeline=_pipeline,
             )
         else:
@@ -268,7 +272,7 @@ class UNet2DConditionLoadersMixin:
 
         return attn_processors
 
-    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, config=None):
         # This method does the following things:
         # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
         # format. For legacy format no filtering is applied.
@@ -316,7 +320,10 @@ class UNet2DConditionLoadersMixin:
                 if "lora_B" in key:
                     rank[key] = val.shape[1]
 
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            if config is not None and isinstance(config, dict) and len(config) > 0:
+                config = json.loads(config["unet"])
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True)
+
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
                     if is_peft_version("<", "0.9.0"):
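
As the `json.loads(config["unet"])` call above suggests, the `config` object threaded through these loaders is a plain mapping from component prefix to a JSON-serialized `LoraConfig`. A minimal illustration with made-up values:

import json

# Shape of the metadata/config mapping consumed by _process_lora (the text
# encoder loaders look up their own prefix instead of "unet").
config = {"unet": json.dumps({"r": 4, "lora_alpha": 8, "use_dora": False})}

unet_config = json.loads(config["unet"])
assert unet_config["lora_alpha"] == 8
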
@@ -21,9 +21,12 @@ from typing import Optional
 
 from packaging import version
 
+from . import logging
 from .import_utils import is_peft_available, is_torch_available
 
 
+logger = logging.get_logger(__name__)
+
 if is_torch_available():
     import torch
 
@@ -147,11 +150,29 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
             module.set_scale(adapter_name, 1.0)
 
 
-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True):
     rank_pattern = {}
     alpha_pattern = {}
     r = lora_alpha = list(rank_dict.values())[0]
 
+    # Try to retrieve config.
+    alpha_retrieved = False
+    if config is not None:
+        lora_alpha = config["lora_alpha"] if "lora_alpha" in config else lora_alpha
+        alpha_retrieved = True
+
+        # We simply ignore the `alpha_pattern` and `rank_pattern` if they are found
+        # in the `config`. This is because:
+        # 1. We determine `rank_pattern` from the `rank_dict`.
+        # 2. When `network_alpha_dict` is passed that means the underlying checkpoint
+        # is a non-diffusers checkpoint.
+        # More details: https://github.com/huggingface/diffusers/pull/9143#discussion_r1711491175
+        if config.get("alpha_pattern", None) is not None:
+            logger.warning("`alpha_pattern` found in the LoRA config. This will be ignored.")
+
+        if config.get("rank_pattern", None) is not None:
+            logger.warning("`rank_pattern` found in the LoRA config. This will be ignored.")
+
     if len(set(rank_dict.values())) > 1:
         # get the rank occuring the most number of times
         r = collections.Counter(rank_dict.values()).most_common()[0][0]
@@ -160,7 +181,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
         rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
         rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
 
-    if network_alpha_dict is not None and len(network_alpha_dict) > 0:
+    if not alpha_retrieved and network_alpha_dict is not None and len(network_alpha_dict) > 0:
         if len(set(network_alpha_dict.values())) > 1:
             # get the alpha occuring the most number of times
             lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
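
A small, hedged sketch of the precedence `get_peft_kwargs` now follows: when a `config` is supplied, its `lora_alpha` is used as-is and `alpha_retrieved` keeps any `network_alpha_dict` from overriding it. The state-dict keys below are made up; only their names matter here:

from diffusers.utils.peft_utils import get_peft_kwargs

rank_dict = {"blocks.0.attn.to_q.lora_B.weight": 4}
peft_state_dict = {
    "blocks.0.attn.to_q.lora_A.weight": None,
    "blocks.0.attn.to_q.lora_B.weight": None,
}

lora_config_kwargs = get_peft_kwargs(
    rank_dict, network_alpha_dict=None, peft_state_dict=peft_state_dict, config={"lora_alpha": 8}
)
assert lora_config_kwargs["lora_alpha"] == 8  # taken from the config, not inferred from the rank
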
@@ -19,7 +19,7 @@ import torch
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
 
 from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, torch_device
 
 
 sys.path.append(".")
@@ -28,6 +28,7 @@ from utils import PeftLoraLoaderMixinTests # noqa: E402
 
 
 @require_peft_backend
+@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
 class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = FluxPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler()

@@ -87,7 +87,7 @@ class PeftLoraLoaderMixinTests:
     transformer_kwargs = None
     vae_kwargs = None
 
-    def get_dummy_components(self, scheduler_cls=None, use_dora=False):
+    def get_dummy_components(self, scheduler_cls=None, use_dora=False, alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
         if self.has_two_text_encoders and self.has_three_text_encoders:
@@ -95,6 +95,7 @@ class PeftLoraLoaderMixinTests:
 
         scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
         rank = 4
+        alpha = rank if alpha is None else alpha
 
         torch.manual_seed(0)
         if self.unet_kwargs is not None:
@@ -120,7 +121,7 @@ class PeftLoraLoaderMixinTests:
 
         text_lora_config = LoraConfig(
             r=rank,
-            lora_alpha=rank,
+            lora_alpha=alpha,
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
             init_lora_weights=False,
             use_dora=use_dora,
@@ -128,7 +129,7 @@ class PeftLoraLoaderMixinTests:
 
         denoiser_lora_config = LoraConfig(
             r=rank,
-            lora_alpha=rank,
+            lora_alpha=alpha,
             target_modules=["to_q", "to_k", "to_v", "to_out.0"],
             init_lora_weights=False,
             use_dora=use_dora,
@@ -1752,3 +1753,85 @@ class PeftLoraLoaderMixinTests:
 
         _, _, inputs = self.get_dummy_inputs()
         _ = pipe(**inputs).images
+
+    def test_if_lora_alpha_is_correctly_parsed(self):
+        lora_alpha = 8
+
+        scheduler_class = FlowMatchEulerDiscreteScheduler if self.uses_flow_matching else DDIMScheduler
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
+            alpha=lora_alpha, scheduler_cls=scheduler_class
+        )
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe.text_encoder.add_adapter(text_lora_config)
+        if self.unet_kwargs is not None:
+            pipe.unet.add_adapter(denoiser_lora_config)
+        else:
+            pipe.transformer.add_adapter(denoiser_lora_config)
+        if self.has_two_text_encoders or self.has_three_text_encoders:
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder_2.add_adapter(text_lora_config)
+
+        # Inference works?
+        _ = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser = pipe.unet if self.unet_kwargs else pipe.transformer
+            denoiser_state_dict = get_peft_model_state_dict(denoiser)
+            denoiser_lora_config = denoiser.peft_config["default"]
+
+            text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
+            text_encoder_lora_config = pipe.text_encoder.peft_config["default"]
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+                    text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"]
+
+            saving_kwargs = {
+                "save_directory": tmpdirname,
+                "text_encoder_lora_layers": text_encoder_state_dict,
+                "text_encoder_lora_config": text_encoder_lora_config,
+            }
+            if self.unet_kwargs is not None:
+                saving_kwargs.update(
+                    {"unet_lora_layers": denoiser_state_dict, "unet_lora_config": denoiser_lora_config}
+                )
+            else:
+                saving_kwargs.update(
+                    {"transformer_lora_layers": denoiser_state_dict, "transformer_lora_config": denoiser_lora_config}
+                )
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+                    saving_kwargs.update(
+                        {
+                            "text_encoder_2_lora_layers": text_encoder_2_state_dict,
+                            "text_encoder_2_lora_config": text_encoder_2_lora_config,
+                        }
+                    )
+
+            self.pipeline_class.save_lora_weights(**saving_kwargs)
+            loaded_pipe = self.pipeline_class(**components)
+            loaded_pipe.load_lora_weights(tmpdirname)
+
+            # Inference works?
+            _ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            denoiser_loaded = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            assert (
+                denoiser_loaded.peft_config["default"].lora_alpha == lora_alpha
+            ), "LoRA alpha not correctly loaded for UNet."
+            assert (
+                loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha
+            ), "LoRA alpha not correctly loaded for text encoder."
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    assert (
+                        loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha
+                    ), "LoRA alpha not correctly loaded for text encoder 2."