Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-17 09:54:41 +08:00)

Compare commits: sf-comfy-l ... lora-old-b (4 commits)

| Author | SHA1 | Date |
|---|---|---|
| | e7b95976b3 | |
| | 20d5f88dd0 | |
| | a86f0b00cd | |
| | 449353298c | |
.github/workflows/pr_tests.yml (vendored): 13 changed lines
```diff
@@ -34,11 +34,6 @@ jobs:
           runner: docker-cpu
           image: diffusers/diffusers-pytorch-cpu
           report: torch_cpu_models_schedulers
-        - name: LoRA
-          framework: lora
-          runner: docker-cpu
-          image: diffusers/diffusers-pytorch-cpu
-          report: torch_cpu_lora
         - name: Fast Flax CPU tests
           framework: flax
           runner: docker-cpu
@@ -94,14 +89,6 @@ jobs:
             --make-reports=tests_${{ matrix.config.report }} \
             tests/models tests/schedulers tests/others

-      - name: Run fast PyTorch LoRA CPU tests
-        if: ${{ matrix.config.framework == 'lora' }}
-        run: |
-          python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-            -s -v -k "not Flax and not Onnx and not Dependency" \
-            --make-reports=tests_${{ matrix.config.report }} \
-            tests/lora
-
       - name: Run fast Flax TPU tests
         if: ${{ matrix.config.framework == 'flax' }}
         run: |
```
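The two workflow hunks drop the dedicated LoRA matrix entry and its test step; the remaining hunks in this compare touch the LoRA loader mixin and the deprecated LoRA model layers. If the fast LoRA tests still need to be run locally, here is a minimal sketch of the removed CI command. Assumptions: it is run from the repository root with pytest and pytest-xdist installed, and the CI-only `--make-reports` flag is dropped since it only matters for report artifacts.

```python
# Local equivalent of the removed "Run fast PyTorch LoRA CPU tests" step (sketch only).
import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main(
            [
                "-n", "2",                 # requires pytest-xdist, as in the CI command
                "--max-worker-restart=0",
                "--dist=loadfile",
                "-s", "-v",
                "-k", "not Flax and not Onnx and not Dependency",
                "tests/lora",
            ]
        )
    )
```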
```diff
@@ -13,7 +13,6 @@
 # limitations under the License.
 import inspect
 import os
-from contextlib import nullcontext
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union

@@ -26,7 +25,7 @@ from packaging import version
 from torch import nn

 from .. import __version__
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
@@ -34,7 +33,6 @@ from ..utils import (
     convert_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
-    deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
@@ -51,10 +49,9 @@ from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_
 if is_transformers_available():
     from transformers import PreTrainedModel

-    from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
+    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules

 if is_accelerate_available():
-    from accelerate import init_empty_weights
     from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

 logger = logging.get_logger(__name__)
@@ -106,6 +103,9 @@ class LoraLoaderMixin:
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

```
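After this change, `load_lora_weights()` simply refuses to run without the PEFT backend instead of falling back to the patched layers. A minimal usage sketch follows; the base model id and the LoRA checkpoint path are placeholders and not part of this compare.

```python
# Sketch only: assumes `peft` is installed so the PEFT backend is active.
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import USE_PEFT_BACKEND

assert USE_PEFT_BACKEND, "Install peft (`pip install peft`) to use the LoRA loaders."

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# `adapter_name` lets several LoRAs coexist; it defaults to `default_{i}` as the docstring notes.
pipe.load_lora_weights("path/to/lora-checkpoint", adapter_name="style")
image = pipe("a painting of a squirrel eating a burger").images[0]
```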
```diff
@@ -397,16 +397,17 @@ class LoraLoaderMixin:
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())

-        if all(key.startswith("unet.unet") for key in keys):
-            deprecation_message = "Keys starting with 'unet.unet' are deprecated."
-            deprecate("unet.unet keys", "0.27", deprecation_message)
-
         if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
@@ -427,9 +428,7 @@ class LoraLoaderMixin:
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
             logger.warn(warn_message)

-        if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
-            from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
+        if len(state_dict.keys()) > 0:
             if adapter_name in getattr(unet, "peft_config", {}):
                 raise ValueError(
                     f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
@@ -518,6 +517,11 @@ class LoraLoaderMixin:
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        from peft import LoraConfig
+
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT

         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -539,34 +543,21 @@ class LoraLoaderMixin:
             rank = {}
             text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

-            if USE_PEFT_BACKEND:
-                # convert state dict
-                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+            # convert state dict
+            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

             for name, _ in text_encoder_attn_modules(text_encoder):
                 rank_key = f"{name}.out_proj.lora_B.weight"
                 rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

             patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
             if patch_mlp:
                 for name, _ in text_encoder_mlp_modules(text_encoder):
                     rank_key_fc1 = f"{name}.fc1.lora_B.weight"
                     rank_key_fc2 = f"{name}.fc2.lora_B.weight"

                     rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
                     rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
-            else:
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
-                    rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
-
-                patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
-                if patch_mlp:
-                    for name, _ in text_encoder_mlp_modules(text_encoder):
-                        rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
-                        rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
-                        rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
-                        rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]

             if network_alphas is not None:
                 alpha_keys = [
```
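The rank bookkeeping above relies on the PEFT weight layout: for a LoRA pair `lora_A` (rank x in_features) and `lora_B` (out_features x rank), the adapter rank is the second dimension of `lora_B.weight`. A tiny self-contained check in plain PyTorch (the key name below is illustrative, not a real checkpoint key):

```python
# Builds a LoRA pair by hand to show why `.shape[1]` of the lora_B weight is the rank,
# exactly as the loader computes it above.
import torch

in_features, out_features, rank = 768, 768, 4
lora_A = torch.nn.Linear(in_features, rank, bias=False)   # A.weight: (rank, in_features)
lora_B = torch.nn.Linear(rank, out_features, bias=False)  # B.weight: (out_features, rank)

state_dict = {
    "text_model.encoder.layers.0.self_attn.out_proj.lora_A.weight": lora_A.weight,
    "text_model.encoder.layers.0.self_attn.out_proj.lora_B.weight": lora_B.weight,
}
rank_key = "text_model.encoder.layers.0.self_attn.out_proj.lora_B.weight"
assert state_dict[rank_key].shape[1] == rank  # == 4
```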
```diff
@@ -576,84 +567,25 @@ class LoraLoaderMixin:
                     k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                 }

-            if USE_PEFT_BACKEND:
-                from peft import LoraConfig
-
-                lora_config_kwargs = get_peft_kwargs(
-                    rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
-                )
-                lora_config = LoraConfig(**lora_config_kwargs)
-
-                # adapter_name
-                if adapter_name is None:
-                    adapter_name = get_adapter_name(text_encoder)
-
-                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-                # inject LoRA layers and load the state dict
-                # in transformers we automatically check whether the adapter name is already in use or not
-                text_encoder.load_adapter(
-                    adapter_name=adapter_name,
-                    adapter_state_dict=text_encoder_lora_state_dict,
-                    peft_config=lora_config,
-                )
-
-                # scale LoRA layers with `lora_scale`
-                scale_lora_layers(text_encoder, weight=lora_scale)
-            else:
-                cls._modify_text_encoder(
-                    text_encoder,
-                    lora_scale,
-                    network_alphas,
-                    rank=rank,
-                    patch_mlp=patch_mlp,
-                    low_cpu_mem_usage=low_cpu_mem_usage,
-                )
-
-                is_pipeline_offloaded = _pipeline is not None and any(
-                    isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
-                    for c in _pipeline.components.values()
-                )
-                if is_pipeline_offloaded and low_cpu_mem_usage:
-                    low_cpu_mem_usage = True
-                    logger.info(
-                        f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
-                    )
-
-                if low_cpu_mem_usage:
-                    device = next(iter(text_encoder_lora_state_dict.values())).device
-                    dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
-                    unexpected_keys = load_model_dict_into_meta(
-                        text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
-                    )
-                else:
-                    load_state_dict_results = text_encoder.load_state_dict(
-                        text_encoder_lora_state_dict, strict=False
-                    )
-                    unexpected_keys = load_state_dict_results.unexpected_keys
-
-                if len(unexpected_keys) != 0:
-                    raise ValueError(
-                        f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
-                    )
-
-                # <Unsafe code
-                # We can be sure that the following works as all we do is change the dtype and device of the text encoder
-                # Now we remove any existing hooks to
-                is_model_cpu_offload = False
-                is_sequential_cpu_offload = False
-                if _pipeline is not None:
-                    for _, component in _pipeline.components.items():
-                        if isinstance(component, torch.nn.Module):
-                            if hasattr(component, "_hf_hook"):
-                                is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                                is_sequential_cpu_offload = isinstance(
-                                    getattr(component, "_hf_hook"), AlignDevicesHook
-                                )
-                                logger.info(
-                                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                                )
-                                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(text_encoder)
+
+            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+            # inject LoRA layers and load the state dict
+            # in transformers we automatically check whether the adapter name is already in use or not
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=text_encoder_lora_state_dict,
+                peft_config=lora_config,
+            )
+
+            # scale LoRA layers with `lora_scale`
+            scale_lora_layers(text_encoder, weight=lora_scale)

             text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

```
```diff
@@ -689,6 +621,8 @@ class LoraLoaderMixin:
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT

         keys = list(state_dict.keys())
@@ -705,8 +639,6 @@ class LoraLoaderMixin:
         }

         if len(state_dict.keys()) > 0:
-            from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
             if adapter_name in getattr(transformer, "peft_config", {}):
                 raise ValueError(
                     f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
```
```diff
@@ -754,118 +686,20 @@ class LoraLoaderMixin:
         return self._lora_scale if hasattr(self, "_lora_scale") else 1.0

     def _remove_text_encoder_monkey_patch(self):
-        if USE_PEFT_BACKEND:
-            remove_method = recurse_remove_peft_layers
-        else:
-            remove_method = self._remove_text_encoder_monkey_patch_classmethod
+        remove_method = recurse_remove_peft_layers

         if hasattr(self, "text_encoder"):
             remove_method(self.text_encoder)

             # In case text encoder have no Lora attached
-            if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None:
+            if getattr(self.text_encoder, "peft_config", None) is not None:
                 del self.text_encoder.peft_config
                 self.text_encoder._hf_peft_config_loaded = None

         if hasattr(self, "text_encoder_2"):
             remove_method(self.text_encoder_2)
-            if USE_PEFT_BACKEND:
+            if getattr(self.text_encoder_2, "peft_config", None) is not None:
                 del self.text_encoder_2.peft_config
                 self.text_encoder_2._hf_peft_config_loaded = None

-    @classmethod
-    def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
-        deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
-
-        for _, attn_module in text_encoder_attn_modules(text_encoder):
-            if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                attn_module.q_proj.lora_linear_layer = None
-                attn_module.k_proj.lora_linear_layer = None
-                attn_module.v_proj.lora_linear_layer = None
-                attn_module.out_proj.lora_linear_layer = None
-
-        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-            if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                mlp_module.fc1.lora_linear_layer = None
-                mlp_module.fc2.lora_linear_layer = None
-
-    @classmethod
-    def _modify_text_encoder(
-        cls,
-        text_encoder,
-        lora_scale=1,
-        network_alphas=None,
-        rank: Union[Dict[str, int], int] = 4,
-        dtype=None,
-        patch_mlp=False,
-        low_cpu_mem_usage=False,
-    ):
-        r"""
-        Monkey-patches the forward passes of attention modules of the text encoder.
-        """
-        deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
-
-        def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
-            linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
-            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
-            with ctx():
-                model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
-
-            lora_parameters.extend(model.lora_linear_layer.parameters())
-            return model
-
-        # First, remove any monkey-patch that might have been applied before
-        cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
-
-        lora_parameters = []
-        network_alphas = {} if network_alphas is None else network_alphas
-        is_network_alphas_populated = len(network_alphas) > 0
-
-        for name, attn_module in text_encoder_attn_modules(text_encoder):
-            query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
-            key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
-            value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
-            out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
-
-            if isinstance(rank, dict):
-                current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
-            else:
-                current_rank = rank
-
-            attn_module.q_proj = create_patched_linear_lora(
-                attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
-            )
-            attn_module.k_proj = create_patched_linear_lora(
-                attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
-            )
-            attn_module.v_proj = create_patched_linear_lora(
-                attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
-            )
-            attn_module.out_proj = create_patched_linear_lora(
-                attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
-            )
-
-        if patch_mlp:
-            for name, mlp_module in text_encoder_mlp_modules(text_encoder):
-                fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None)
-                fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None)
-
-                current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
-                current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
-
-                mlp_module.fc1 = create_patched_linear_lora(
-                    mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
-                )
-                mlp_module.fc2 = create_patched_linear_lora(
-                    mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
-                )
-
-        if is_network_alphas_populated and len(network_alphas) > 0:
-            raise ValueError(
-                f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
-            )
-
-        return lora_parameters
-
     @classmethod
     def save_lora_weights(
         cls,
```
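With `_modify_text_encoder` and the classmethod monkey-patch remover gone, cleanup is PEFT-only via `recurse_remove_peft_layers`. A short usage sketch from the pipeline side (checkpoint ids are placeholders; `unload_lora_weights()` is the public entry point for removing loaded LoRA layers):

```python
# Sketch only: model and LoRA ids are placeholders. Loading attaches PEFT layers to the
# UNet and text encoder; unload_lora_weights() strips them again, which under the PEFT
# backend goes through the peft-layer removal shown in the hunk above.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora-checkpoint")
image_with_lora = pipe("a pixel-art squirrel").images[0]

pipe.unload_lora_weights()  # back to the base weights
image_without_lora = pipe("a pixel-art squirrel").images[0]
```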
````diff
@@ -1039,6 +873,8 @@ class LoraLoaderMixin:
         pipeline.fuse_lora(lora_scale=0.7)
         ```
         """
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
         if fuse_unet or fuse_text_encoder:
             self.num_fused_loras += 1
             if self.num_fused_loras > 1:
````
```diff
@@ -1050,52 +886,26 @@ class LoraLoaderMixin:
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)

-        if USE_PEFT_BACKEND:
-            from peft.tuners.tuners_utils import BaseTunerLayer
-
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
-                merge_kwargs = {"safe_merge": safe_fusing}
-
-                for module in text_encoder.modules():
-                    if isinstance(module, BaseTunerLayer):
-                        if lora_scale != 1.0:
-                            module.scale_layer(lora_scale)
-
-                        # For BC with previous PEFT versions, we need to check the signature
-                        # of the `merge` method to see if it supports the `adapter_names` argument.
-                        supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
-                        if "adapter_names" in supported_merge_kwargs:
-                            merge_kwargs["adapter_names"] = adapter_names
-                        elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
-                            raise ValueError(
-                                "The `adapter_names` argument is not supported with your PEFT version. "
-                                "Please upgrade to the latest version of PEFT. `pip install -U peft`"
-                            )
-
-                        module.merge(**merge_kwargs)
-
-        else:
-            deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
-
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
-                if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
-                    raise ValueError(
-                        "The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
-                        "backend to use this argument by installing latest PEFT and transformers."
-                        " `pip install -U peft transformers`"
-                    )
-
-                for _, attn_module in text_encoder_attn_modules(text_encoder):
-                    if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                        attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
-                        attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
-                        attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
-                        attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
-
-                for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-                    if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                        mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
-                        mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
+        def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+            merge_kwargs = {"safe_merge": safe_fusing}
+
+            for module in text_encoder.modules():
+                if isinstance(module, BaseTunerLayer):
+                    if lora_scale != 1.0:
+                        module.scale_layer(lora_scale)
+
+                    # For BC with previous PEFT versions, we need to check the signature
+                    # of the `merge` method to see if it supports the `adapter_names` argument.
+                    supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                    if "adapter_names" in supported_merge_kwargs:
+                        merge_kwargs["adapter_names"] = adapter_names
+                    elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                        raise ValueError(
+                            "The `adapter_names` argument is not supported with your PEFT version. "
+                            "Please upgrade to the latest version of PEFT. `pip install -U peft`"
+                        )
+
+                    module.merge(**merge_kwargs)

         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
```
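Fusing keeps the same user-facing API; only the non-PEFT branch is removed. A short usage sketch (model and LoRA ids are placeholders; `adapter_names` needs a recent PEFT, which the signature check above enforces):

```python
# Sketch: merge a named adapter into the base weights for faster inference, then undo it.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora-checkpoint", adapter_name="style")

pipe.fuse_lora(lora_scale=0.7, adapter_names=["style"])  # merges into UNet/text encoder weights
image = pipe("a painting of a squirrel eating a burger").images[0]

pipe.unfuse_lora()  # unmerges, restoring the separate LoRA layers
```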
```diff
@@ -1120,40 +930,18 @@ class LoraLoaderMixin:
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
         """
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         if unfuse_unet:
-            if not USE_PEFT_BACKEND:
-                unet.unfuse_lora()
-            else:
-                from peft.tuners.tuners_utils import BaseTunerLayer
-
-                for module in unet.modules():
-                    if isinstance(module, BaseTunerLayer):
-                        module.unmerge()
-
-        if USE_PEFT_BACKEND:
-            from peft.tuners.tuners_utils import BaseTunerLayer
-
-            def unfuse_text_encoder_lora(text_encoder):
-                for module in text_encoder.modules():
-                    if isinstance(module, BaseTunerLayer):
-                        module.unmerge()
-
-        else:
-            deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
-
-            def unfuse_text_encoder_lora(text_encoder):
-                for _, attn_module in text_encoder_attn_modules(text_encoder):
-                    if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                        attn_module.q_proj._unfuse_lora()
-                        attn_module.k_proj._unfuse_lora()
-                        attn_module.v_proj._unfuse_lora()
-                        attn_module.out_proj._unfuse_lora()
-
-                for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-                    if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                        mlp_module.fc1._unfuse_lora()
-                        mlp_module.fc2._unfuse_lora()
+            for module in unet.modules():
+                if isinstance(module, BaseTunerLayer):
+                    module.unmerge()
+
+        def unfuse_text_encoder_lora(text_encoder):
+            for module in text_encoder.modules():
+                if isinstance(module, BaseTunerLayer):
+                    module.unmerge()

         if unfuse_text_encoder:
             if hasattr(self, "text_encoder"):
```
```diff
@@ -1434,6 +1222,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
         # We could have accessed the unet config from `lora_state_dict()` too. We pass
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.
```
```diff
@@ -1538,17 +1329,13 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         )

     def _remove_text_encoder_monkey_patch(self):
-        if USE_PEFT_BACKEND:
-            recurse_remove_peft_layers(self.text_encoder)
-            # TODO: @younesbelkada handle this in transformers side
-            if getattr(self.text_encoder, "peft_config", None) is not None:
-                del self.text_encoder.peft_config
-                self.text_encoder._hf_peft_config_loaded = None
-
-            recurse_remove_peft_layers(self.text_encoder_2)
-            if getattr(self.text_encoder_2, "peft_config", None) is not None:
-                del self.text_encoder_2.peft_config
-                self.text_encoder_2._hf_peft_config_loaded = None
-        else:
-            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
-            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
+        recurse_remove_peft_layers(self.text_encoder)
+        # TODO: @younesbelkada handle this in transformers side
+        if getattr(self.text_encoder, "peft_config", None) is not None:
+            del self.text_encoder.peft_config
+            self.text_encoder._hf_peft_config_loaded = None
+
+        recurse_remove_peft_layers(self.text_encoder_2)
+        if getattr(self.text_encoder_2, "peft_config", None) is not None:
+            del self.text_encoder_2.peft_config
+            self.text_encoder_2._hf_peft_config_loaded = None
```
```diff
@@ -27,7 +27,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from ..utils import logging
+from ..utils import deprecate, logging
 from ..utils.import_utils import is_transformers_available


@@ -82,6 +82,9 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):

 class PatchedLoraProjection(torch.nn.Module):
     def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+        deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("PatchedLoraProjection", "1.0.0", deprecation_message)
+
         super().__init__()
         from ..models.lora import LoRALinearLayer

@@ -293,10 +296,16 @@ class LoRACompatibleConv(nn.Conv2d):
     """

     def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
+        deprecation_message = "Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)
+
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer

     def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
+        deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("set_lora_layer", "1.0.0", deprecation_message)
+
         self.lora_layer = lora_layer

     def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
@@ -371,10 +380,15 @@ class LoRACompatibleLinear(nn.Linear):
     """

     def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
+        deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
+
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer

     def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
+        deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("set_lora_layer", "1.0.0", deprecation_message)
         self.lora_layer = lora_layer

     def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
```
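The legacy layer classes themselves are kept but now warn on construction. A small sketch of what that looks like from user code (assumes a diffusers build that includes these changes; the `deprecate()` helper emits a `FutureWarning` by default):

```python
# Instantiating one of the legacy layers triggers the deprecation message added above,
# pointing users at the PEFT backend.
import warnings

from diffusers.models.lora import LoRACompatibleLinear

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    layer = LoRACompatibleLinear(320, 320)

assert any("deprecated" in str(w.message).lower() for w in caught)
```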
File diff suppressed because it is too large.
```diff
@@ -1,64 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import DiffusionPipeline
-from diffusers.utils.testing_utils import torch_device
-
-
-class PEFTLoRALoading(unittest.TestCase):
-    def get_dummy_inputs(self):
-        pipeline_inputs = {
-            "prompt": "A painting of a squirrel eating a burger",
-            "num_inference_steps": 2,
-            "guidance_scale": 6.0,
-            "output_type": "np",
-            "generator": torch.manual_seed(0),
-        }
-        return pipeline_inputs
-
-    def test_stable_diffusion_peft_lora_loading_in_non_peft(self):
-        sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
-        # This LoRA was obtained using similarly as how it's done in the training scripts.
-        # For details on how the LoRA was obtained, refer to:
-        # https://hf.co/datasets/diffusers/notebooks/blob/main/check_logits_with_serialization_peft_lora.py
-        sd_pipe.load_lora_weights("hf-internal-testing/tiny-sd-lora-peft")
-
-        inputs = self.get_dummy_inputs()
-        outputs = sd_pipe(**inputs).images
-
-        predicted_slice = outputs[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.5396, 0.5707, 0.477, 0.4665, 0.5419, 0.4594, 0.4857, 0.4741, 0.4804])
-
-        self.assertTrue(outputs.shape == (1, 64, 64, 3))
-        assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
-
-    def test_stable_diffusion_xl_peft_lora_loading_in_non_peft(self):
-        sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-pipe").to(torch_device)
-        # This LoRA was obtained using similarly as how it's done in the training scripts.
-        sd_pipe.load_lora_weights("hf-internal-testing/tiny-sdxl-lora-peft")
-
-        inputs = self.get_dummy_inputs()
-        outputs = sd_pipe(**inputs).images
-
-        predicted_slice = outputs[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.613, 0.5566, 0.54, 0.4162, 0.4042, 0.4596, 0.5374, 0.5286, 0.5038])
-
-        self.assertTrue(outputs.shape == (1, 64, 64, 3))
-        assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
```
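The deleted test exercised loading PEFT-serialized LoRA checkpoints through the old non-PEFT layers, which no longer exist after this compare. A rough PEFT-backend equivalent using the same tiny checkpoints (sketch only; the recorded pixel slices above belong to the removed code path, so only the output shape is asserted):

```python
# Sketch: reuses the tiny test checkpoints from the deleted file, but runs through the
# PEFT backend (requires `peft` installed). Reference slices are intentionally omitted.
import unittest

import torch
from diffusers import DiffusionPipeline
from diffusers.utils.testing_utils import torch_device


class PEFTLoRALoadingSmokeTest(unittest.TestCase):
    def test_tiny_sd_lora_loads_and_runs(self):
        pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
        pipe.load_lora_weights("hf-internal-testing/tiny-sd-lora-peft")

        images = pipe(
            "A painting of a squirrel eating a burger",
            num_inference_steps=2,
            guidance_scale=6.0,
            output_type="np",
            generator=torch.manual_seed(0),
        ).images
        self.assertEqual(images.shape, (1, 64, 64, 3))
```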