mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-10 18:51:46 +08:00
Compare commits
6 Commits
diffusers-
...
refactor-l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aff05dc742 | ||
|
|
edf5ba6a17 | ||
|
|
9941f1f61b | ||
|
|
46a9db0336 | ||
|
|
370146e4e0 | ||
|
|
5cd45c24bf |
2
.github/workflows/build_documentation.yml
vendored
2
.github/workflows/build_documentation.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
package: diffusers
|
||||
notebook_folder: diffusers_doc
|
||||
languages: en ko zh ja pt
|
||||
|
||||
custom_container: diffusers/diffusers-doc-builder
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
|
||||
1
.github/workflows/build_pr_documentation.yml
vendored
1
.github/workflows/build_pr_documentation.yml
vendored
@@ -20,3 +20,4 @@ jobs:
|
||||
install_libgl1: true
|
||||
package: diffusers
|
||||
languages: en ko zh ja pt
|
||||
custom_container: diffusers/diffusers-doc-builder
|
||||
|
||||
@@ -69,6 +69,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) |
|
||||
| Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
|
||||
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
|
||||
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
@@ -4035,6 +4036,93 @@ onestep_image = pipe(prompt, num_inference_steps=1).images[0]
|
||||
multistep_image = pipe(prompt, num_inference_steps=4).images[0]
|
||||
```
|
||||
|
||||
### FRESCO
|
||||
|
||||
This is the Diffusers implementation of zero-shot video-to-video translation pipeline [FRESCO](https://github.com/williamyang1991/FRESCO) (without Ebsynth postprocessing and background smooth). To run the code, please install gmflow. Then modify the path in `gmflow_dir`. After that, you can run the pipeline with:
|
||||
|
||||
```py
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from diffusers import ControlNetModel,DDIMScheduler, DiffusionPipeline
|
||||
import sys
|
||||
gmflow_dir = "/path/to/gmflow"
|
||||
sys.path.insert(0, gmflow_dir)
|
||||
|
||||
def video_to_frame(video_path: str, interval: int):
|
||||
vidcap = cv2.VideoCapture(video_path)
|
||||
success = True
|
||||
|
||||
count = 0
|
||||
res = []
|
||||
while success:
|
||||
count += 1
|
||||
success, image = vidcap.read()
|
||||
if count % interval != 1:
|
||||
continue
|
||||
if image is not None:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
res.append(image)
|
||||
if len(res) >= 8:
|
||||
break
|
||||
|
||||
vidcap.release()
|
||||
return res
|
||||
|
||||
|
||||
input_video_path = 'https://github.com/williamyang1991/FRESCO/raw/main/data/car-turn.mp4'
|
||||
output_video_path = 'car.gif'
|
||||
|
||||
# You can use any fintuned SD here
|
||||
model_path = 'SG161222/Realistic_Vision_V2.0'
|
||||
|
||||
prompt = 'a red car turns in the winter'
|
||||
a_prompt = ', RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, '
|
||||
n_prompt = '(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation'
|
||||
|
||||
input_interval = 5
|
||||
frames = video_to_frame(
|
||||
input_video_path, input_interval)
|
||||
|
||||
control_frames = []
|
||||
# get canny image
|
||||
for frame in frames:
|
||||
image = cv2.Canny(frame, 50, 100)
|
||||
np_image = np.array(image)
|
||||
np_image = np_image[:, :, None]
|
||||
np_image = np.concatenate([np_image, np_image, np_image], axis=2)
|
||||
canny_image = Image.fromarray(np_image)
|
||||
control_frames.append(canny_image)
|
||||
|
||||
# You can use any ControlNet here
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/sd-controlnet-canny").to('cuda')
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_path, controlnet=controlnet, custom_pipeline='fresco_v2v').to('cuda')
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
|
||||
output_frames = pipe(
|
||||
prompt + a_prompt,
|
||||
frames,
|
||||
control_frames,
|
||||
num_inference_steps=20,
|
||||
strength=0.75,
|
||||
controlnet_conditioning_scale=0.7,
|
||||
generator=generator,
|
||||
negative_prompt=n_prompt
|
||||
).images
|
||||
|
||||
output_frames[0].save(output_video_path, save_all=True,
|
||||
append_images=output_frames[1:], duration=100, loop=0)
|
||||
|
||||
```
|
||||
|
||||
# Perturbed-Attention Guidance
|
||||
|
||||
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
|
||||
|
||||
2511
examples/community/fresco_v2v.py
Normal file
2511
examples/community/fresco_v2v.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -31,6 +31,7 @@ from ..utils import (
|
||||
is_transformers_available,
|
||||
is_xformers_available,
|
||||
)
|
||||
from ..utils.testing_utils import get_python_version
|
||||
from . import BaseDiffusersCLICommand
|
||||
|
||||
|
||||
@@ -105,6 +106,11 @@ class EnvironmentCommand(BaseDiffusersCLICommand):
|
||||
|
||||
xformers_version = xformers.__version__
|
||||
|
||||
if get_python_version() >= (3, 10):
|
||||
platform_info = f"{platform.freedesktop_os_release().get('PRETTY_NAME', None)} - {platform.platform()}"
|
||||
else:
|
||||
platform_info = platform.platform()
|
||||
|
||||
is_notebook_str = "Yes" if is_notebook() else "No"
|
||||
|
||||
is_google_colab_str = "Yes" if is_google_colab() else "No"
|
||||
@@ -152,7 +158,7 @@ class EnvironmentCommand(BaseDiffusersCLICommand):
|
||||
|
||||
info = {
|
||||
"🤗 Diffusers version": version,
|
||||
"Platform": f"{platform.freedesktop_os_release().get('PRETTY_NAME', None)} - {platform.platform()}",
|
||||
"Platform": platform_info,
|
||||
"Running on a notebook?": is_notebook_str,
|
||||
"Running on Google Colab?": is_google_colab_str,
|
||||
"Python version": platform.python_version(),
|
||||
|
||||
@@ -45,7 +45,7 @@ from ..utils import (
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
|
||||
from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -302,7 +302,7 @@ class LoraLoaderMixin:
|
||||
if unet_config is not None:
|
||||
# use unet config to remap block numbers
|
||||
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
|
||||
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
|
||||
state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
|
||||
|
||||
return state_dict, network_alphas
|
||||
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import is_peft_version, logging
|
||||
|
||||
@@ -123,164 +126,163 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
|
||||
def _convert_non_diffusers_lora_to_diffusers(
|
||||
state_dict: Dict[str, torch.Tensor], unet_name: str = "unet", text_encoder_name: str = "text_encoder"
|
||||
) -> Tuple[Dict[str, Any], Dict[str, float]]:
|
||||
def detect_dora_lora(state_dict: Dict[str, torch.Tensor]) -> Tuple[bool, bool, bool]:
|
||||
is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
|
||||
is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
|
||||
is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
|
||||
return is_unet_dora_lora, is_te_dora_lora, is_te2_dora_lora
|
||||
|
||||
def check_peft_version(is_unet_dora_lora: bool, is_te_dora_lora: bool, is_te2_dora_lora: bool):
|
||||
if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
|
||||
def rename_keys(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
key: str,
|
||||
unet_state_dict: Dict[str, torch.Tensor],
|
||||
te_state_dict: Dict[str, torch.Tensor],
|
||||
te2_state_dict: Dict[str, torch.Tensor],
|
||||
is_unet_dora_lora: bool,
|
||||
is_te_dora_lora: bool,
|
||||
is_te2_dora_lora: bool,
|
||||
):
|
||||
lora_name = key.split(".")[0]
|
||||
lora_name_up = lora_name + ".lora_up.weight"
|
||||
diffusers_name = key.replace(lora_name + ".", "").replace("_", ".")
|
||||
lora_type = lora_name.split("_")[1]
|
||||
|
||||
if lora_type == "unet":
|
||||
diffusers_name = _adjust_unet_names(diffusers_name)
|
||||
unet_state_dict = _populate_state_dict(
|
||||
unet_state_dict, state_dict, key, lora_name_up, diffusers_name, is_unet_dora_lora
|
||||
)
|
||||
else:
|
||||
diffusers_name = _adjust_text_encoder_names(diffusers_name)
|
||||
if lora_type in ["te", "te1"]:
|
||||
te_state_dict = _populate_state_dict(
|
||||
te_state_dict, state_dict, key, lora_name_up, diffusers_name, is_te_dora_lora
|
||||
)
|
||||
else:
|
||||
te2_state_dict = _populate_state_dict(
|
||||
te2_state_dict, state_dict, key, lora_name_up, diffusers_name, is_te2_dora_lora
|
||||
)
|
||||
|
||||
return unet_state_dict, te_state_dict, te2_state_dict
|
||||
|
||||
def _adjust_unet_names(name: str) -> str:
|
||||
replacements = [
|
||||
("input.blocks", "down_blocks"),
|
||||
("down.blocks", "down_blocks"),
|
||||
("middle.block", "mid_block"),
|
||||
("mid.block", "mid_block"),
|
||||
("output.blocks", "up_blocks"),
|
||||
("up.blocks", "up_blocks"),
|
||||
("transformer.blocks", "transformer_blocks"),
|
||||
("to.q.lora", "to_q_lora"),
|
||||
("to.k.lora", "to_k_lora"),
|
||||
("to.v.lora", "to_v_lora"),
|
||||
("to.out.0.lora", "to_out_lora"),
|
||||
("proj.in", "proj_in"),
|
||||
("proj.out", "proj_out"),
|
||||
("emb.layers", "time_emb_proj"),
|
||||
("time.emb.proj", "time_emb_proj"),
|
||||
("conv.shortcut", "conv_shortcut"),
|
||||
("skip.connection", "conv_shortcut"),
|
||||
]
|
||||
for old, new in replacements:
|
||||
name = name.replace(old, new)
|
||||
if "emb" in name and "time.emb.proj" not in name:
|
||||
pattern = r"\.\d+(?=\D*$)"
|
||||
name = re.sub(pattern, "", name, count=1)
|
||||
if ".in." in name:
|
||||
name = name.replace("in.layers.2", "conv1")
|
||||
if ".out." in name:
|
||||
name = name.replace("out.layers.3", "conv2")
|
||||
if "downsamplers" in name or "upsamplers" in name:
|
||||
name = name.replace("op", "conv")
|
||||
return name
|
||||
|
||||
def _adjust_text_encoder_names(name: str) -> str:
|
||||
replacements = [
|
||||
("text.model", "text_model"),
|
||||
("self.attn", "self_attn"),
|
||||
("q.proj.lora", "to_q_lora"),
|
||||
("k.proj.lora", "to_k_lora"),
|
||||
("v.proj.lora", "to_v_lora"),
|
||||
("out.proj.lora", "to_out_lora"),
|
||||
("text.projection", "text_projection"),
|
||||
]
|
||||
for old, new in replacements:
|
||||
name = name.replace(old, new)
|
||||
return name
|
||||
|
||||
def _populate_state_dict(state_dict, main_dict, down_key, up_key, name, is_dora_lora):
|
||||
state_dict[name] = main_dict.pop(down_key)
|
||||
state_dict[name.replace(".down.", ".up.")] = main_dict.pop(up_key)
|
||||
if is_dora_lora:
|
||||
dora_key = down_key.replace("lora_down.weight", "dora_scale")
|
||||
scale_key = "_lora.down." if "_lora.down." in name else ".lora.down."
|
||||
state_dict[name.replace(scale_key, ".lora_magnitude_vector.")] = main_dict.pop(dora_key)
|
||||
return state_dict
|
||||
|
||||
def update_network_alphas(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
network_alphas: Dict[str, float],
|
||||
diffusers_name: str,
|
||||
lora_name_alpha: str,
|
||||
):
|
||||
if lora_name_alpha in state_dict:
|
||||
alpha = state_dict.pop(lora_name_alpha).item()
|
||||
prefix = (
|
||||
"unet."
|
||||
if "unet" in lora_name_alpha
|
||||
else "text_encoder."
|
||||
if "te1" in lora_name_alpha
|
||||
else "text_encoder_2."
|
||||
)
|
||||
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
|
||||
network_alphas.update({new_name: alpha})
|
||||
|
||||
unet_state_dict = {}
|
||||
te_state_dict = {}
|
||||
te2_state_dict = {}
|
||||
network_alphas = {}
|
||||
is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
|
||||
is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
|
||||
is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
|
||||
|
||||
if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
is_unet_dora_lora, is_te_dora_lora, is_te2_dora_lora = detect_dora_lora(state_dict)
|
||||
check_peft_version(is_unet_dora_lora, is_te_dora_lora, is_te2_dora_lora)
|
||||
|
||||
# every down weight has a corresponding up weight and potentially an alpha weight
|
||||
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
|
||||
for key in lora_keys:
|
||||
unet_state_dict, te_state_dict, te2_state_dict = rename_keys(
|
||||
state_dict,
|
||||
key,
|
||||
unet_state_dict,
|
||||
te_state_dict,
|
||||
te2_state_dict,
|
||||
is_unet_dora_lora,
|
||||
is_te_dora_lora,
|
||||
is_te2_dora_lora,
|
||||
)
|
||||
lora_name = key.split(".")[0]
|
||||
lora_name_up = lora_name + ".lora_up.weight"
|
||||
lora_name_alpha = lora_name + ".alpha"
|
||||
diffusers_name = key.replace(lora_name + ".", "").replace("_", ".")
|
||||
update_network_alphas(state_dict, network_alphas, diffusers_name, lora_name_alpha)
|
||||
|
||||
if lora_name.startswith("lora_unet_"):
|
||||
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
|
||||
if state_dict:
|
||||
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
|
||||
|
||||
if "input.blocks" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
|
||||
else:
|
||||
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
|
||||
|
||||
if "middle.block" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
|
||||
else:
|
||||
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
|
||||
if "output.blocks" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
|
||||
else:
|
||||
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
|
||||
|
||||
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
|
||||
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
|
||||
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
|
||||
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
|
||||
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
|
||||
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
|
||||
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
|
||||
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
|
||||
|
||||
# SDXL specificity.
|
||||
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
|
||||
pattern = r"\.\d+(?=\D*$)"
|
||||
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
|
||||
if ".in." in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
|
||||
if ".out." in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
|
||||
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("op", "conv")
|
||||
if "skip" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
|
||||
|
||||
# LyCORIS specificity.
|
||||
if "time.emb.proj" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
|
||||
if "conv.shortcut" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
|
||||
|
||||
# General coverage.
|
||||
if "transformer_blocks" in diffusers_name:
|
||||
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
||||
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
|
||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
elif "ff" in diffusers_name:
|
||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
|
||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
else:
|
||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
if is_unet_dora_lora:
|
||||
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
||||
unet_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
|
||||
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
|
||||
else:
|
||||
key_to_replace = "lora_te2_"
|
||||
|
||||
diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
|
||||
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
||||
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
||||
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
||||
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
||||
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
||||
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
||||
if "self_attn" in diffusers_name:
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
te_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
else:
|
||||
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
elif "mlp" in diffusers_name:
|
||||
# Be aware that this is the new diffusers convention and the rest of the code might
|
||||
# not utilize it yet.
|
||||
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
te_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
else:
|
||||
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
||||
dora_scale_key_to_replace_te = (
|
||||
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
||||
)
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
te_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
elif lora_name.startswith("lora_te2_"):
|
||||
te2_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
|
||||
# Rename the alphas so that they can be mapped appropriately.
|
||||
if lora_name_alpha in state_dict:
|
||||
alpha = state_dict.pop(lora_name_alpha).item()
|
||||
if lora_name_alpha.startswith("lora_unet_"):
|
||||
prefix = "unet."
|
||||
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
|
||||
prefix = "text_encoder."
|
||||
else:
|
||||
prefix = "text_encoder_2."
|
||||
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
|
||||
network_alphas.update({new_name: alpha})
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}")
|
||||
|
||||
logger.info("Kohya-style checkpoint detected.")
|
||||
logger.info("Non-diffusers LoRA checkpoint detected.")
|
||||
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
||||
te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
|
||||
te2_state_dict = (
|
||||
{f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
|
||||
if len(te2_state_dict) > 0
|
||||
else None
|
||||
)
|
||||
if te2_state_dict is not None:
|
||||
|
||||
if te2_state_dict:
|
||||
te2_state_dict = {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
|
||||
te_state_dict.update(te2_state_dict)
|
||||
|
||||
new_state_dict = {**unet_state_dict, **te_state_dict}
|
||||
|
||||
@@ -340,7 +340,7 @@ class FromSingleFileMixin:
|
||||
deprecate("original_config_file", "1.0.0", deprecation_message)
|
||||
original_config = original_config_file
|
||||
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
|
||||
@@ -166,7 +166,7 @@ class FromOriginalModelMixin:
|
||||
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
|
||||
)
|
||||
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
|
||||
Reference in New Issue
Block a user