Mirror of https://github.com/huggingface/diffusers.git (synced 2026-02-21 18:30:38 +08:00)

Compare commits (15 commits)
Comparing `requiremen`...`unet-model`
| SHA1 |
|---|
| ecfd3b4f99 |
| 5f8303fe3c |
| 99de4ceab8 |
| c6e6992cdd |
| ecbaed793d |
| 0411da7739 |
| ffb254a273 |
| ea08148bbd |
| 3a610814a3 |
| ca4a7b0649 |
| 3371560f1d |
| 46d44b73d8 |
| 2b67fb65ef |
| 0e42a3ff93 |
| 14439ab793 |
.github/workflows/pr_modular_tests.yml (2 changes, vendored)
@@ -117,7 +117,7 @@ jobs:
      - name: Install dependencies
        run: |
          uv pip install -e ".[quality,test]"
          uv pip install -e ".[quality]"
          #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
          uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
.github/workflows/pr_tests.yml (6 changes, vendored)
@@ -114,7 +114,7 @@ jobs:
      - name: Install dependencies
        run: |
          uv pip install -e ".[quality,test]"
          uv pip install -e ".[quality]"
          #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
          uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps

@@ -191,7 +191,7 @@ jobs:
      - name: Install dependencies
        run: |
          uv pip install -e ".[quality,test]"
          uv pip install -e ".[quality]"

      - name: Environment
        run: |

@@ -242,7 +242,7 @@ jobs:
      - name: Install dependencies
        run: |
          uv pip install -e ".[quality,test]"
          uv pip install -e ".[quality]"
          # TODO (sayakpaul, DN6): revisit `--no-deps`
          uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
          uv pip install -U tokenizers
@@ -332,49 +332,4 @@ Make your custom block work with Mellon's visual interface. See the [Mellon Cust

Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.

</hfoption>
</hfoptions>

## Dependencies

Declaring package dependencies in custom blocks prevents runtime import errors later on. Diffusers validates the declared dependencies and logs a warning if a package is missing or incompatible.

Set a `_requirements` attribute in your block class, mapping package names to version specifiers.

```py
from diffusers.modular_pipelines import PipelineBlock


class MyCustomBlock(PipelineBlock):
    _requirements = {
        "transformers": ">=4.44.0",
        "sentencepiece": ">=0.2.0"
    }
```
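
Bare version strings are also accepted. Based on the `_normalize_requirements` helper shown further down this diff, a specifier without a comparison operator is treated as an exact pin (the block name below is hypothetical, for illustration only):

```py
class MyPinnedBlock(PipelineBlock):
    # "0.2.0" has no comparison operator, so it is normalized to "==0.2.0"
    _requirements = {"sentencepiece": "0.2.0"}
```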
When a pipeline combines blocks with different requirements, Diffusers merges them.

```py
from diffusers.modular_pipelines import SequentialPipelineBlocks


class BlockA(PipelineBlock):
    _requirements = {"transformers": ">=4.44.0"}
    # ...

class BlockB(PipelineBlock):
    _requirements = {"sentencepiece": ">=0.2.0"}
    # ...

pipe = SequentialPipelineBlocks.from_blocks_dict({
    "block_a": BlockA,
    "block_b": BlockB,
})
```
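
Based on the `_requirements` property on `SequentialPipelineBlocks` and the `_normalize_requirements` helper shown later in this diff, a sketch of what this merge produces:

```py
# The composed pipeline exposes a per-block mapping:
pipe._requirements
# {"block_a": {"transformers": ">=4.44.0"}, "block_b": {"sentencepiece": ">=0.2.0"}}

# At save time the nested mapping is flattened recursively, so the stored
# requirements become: {"transformers": ">=4.44.0", "sentencepiece": ">=0.2.0"}
```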
When a block is saved with [`~ModularPipeline.save_pretrained`], its requirements are written to the `modular_config.json` file. When the block is loaded, Diffusers checks each requirement against the current environment. If a package is missing or its version doesn't match, Diffusers logs one of the following warnings.

```md
# missing package
xyz-package was specified in the requirements but wasn't found in the current environment.

# version mismatch
xyz requirement 'specific-version' is not satisfied by the installed version 'actual-version'. Things might work unexpectedly.
```
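
A minimal sketch of that round trip, assuming a hypothetical local directory or Hub repo id `"my-blocks"` (the `trust_remote_code=True` flag is the usual pattern for loading custom block code and is an assumption here, not shown in this diff):

```py
pipe.save_pretrained("my-blocks")  # requirements land in modular_config.json

# On load, each requirement is validated against the current environment and
# any mismatch or missing package is logged as a warning.
from diffusers.modular_pipelines import ModularPipelineBlocks

blocks = ModularPipelineBlocks.from_pretrained("my-blocks", trust_remote_code=True)
```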
</hfoptions>
setup.py (4 changes)
@@ -101,7 +101,6 @@ _deps = [
    "datasets",
    "filelock",
    "flax>=0.4.1",
    "ftfy",
    "hf-doc-builder>=0.3.0",
    "httpx<1.0.0",
    "huggingface-hub>=0.34.0,<2.0",

@@ -222,14 +221,12 @@ extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft", "timm")
extras["test"] = deps_list(
    "compel",
    "ftfy",
    "GitPython",
    "datasets",
    "Jinja2",
    "invisible-watermark",
    "librosa",
    "parameterized",
    "protobuf",
    "pytest",
    "pytest-timeout",
    "pytest-xdist",

@@ -238,7 +235,6 @@ extras["test"] = deps_list(
    "sentencepiece",
    "scipy",
    "tiktoken",
    "torchsde",
    "torchvision",
    "transformers",
    "phonemizer",
@@ -89,6 +89,8 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
        # automap = self._create_automap(parent_class=parent_class, child_class=child_class)
        # with open(CONFIG, "w") as f:
        #     json.dump(automap, f)
        with open("requirements.txt", "w") as f:
            f.write("")

    def _choose_block(self, candidates, chosen=None):
        for cls, base in candidates:
@@ -8,7 +8,6 @@ deps = {
    "datasets": "datasets",
    "filelock": "filelock",
    "flax": "flax>=0.4.1",
    "ftfy": "ftfy",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "httpx": "httpx<1.0.0",
    "huggingface-hub": "huggingface-hub>=0.34.0,<2.0",
@@ -1117,26 +1117,6 @@ def _sage_attention_backward_op(
    raise NotImplementedError("Backward pass is not implemented for Sage attention.")


def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None):
    # Skip the attention mask if all values are 1; a `None` mask can speed up the computation
    if attn_mask is not None and torch.all(attn_mask != 0):
        attn_mask = None

    # Reshape the attention mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
    # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
    if (
        attn_mask is not None
        and attn_mask.ndim == 2
        and attn_mask.shape[0] == query.shape[0]
        and attn_mask.shape[1] == key.shape[1]
    ):
        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
        attn_mask = ~attn_mask.to(torch.bool)
        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()

    return attn_mask
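
Editor's illustration, not part of the diff: a shape walkthrough of the helper above, assuming BSND-layout tensors and that `_maybe_modify_attn_mask_npu` as defined above is in scope.

```py
import torch

B, Sq, Skv = 2, 4, 6
query = torch.randn(B, Sq, 8, 16)           # [batch, seq_q, num_heads, head_dim]
key = torch.randn(B, Skv, 8, 16)
attn_mask = torch.ones(B, Skv, dtype=torch.bool)
attn_mask[0, -2:] = False                   # pad out two positions in sample 0

out = _maybe_modify_attn_mask_npu(query, key, attn_mask)
# out has shape [B, 1, Sq, Skv]; True now marks the *masked* positions,
# since the mask is inverted for npu_fusion_attention.
# An all-ones input mask would instead be replaced with None to skip masking.
```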


def _npu_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,

@@ -1154,14 +1134,11 @@ def _npu_attention_forward_op(
    if return_lse:
        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")

    attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)

    out = npu_fusion_attention(
        query,
        key,
        value,
        query.size(2),  # num_heads
        atten_mask=attn_mask,
        input_layout="BSND",
        pse=None,
        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,

@@ -2691,17 +2668,16 @@ def _native_npu_attention(
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for NPU attention")
    if return_lse:
        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
    if _parallel_config is None:
        attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)

        out = npu_fusion_attention(
            query,
            key,
            value,
            query.size(2),  # num_heads
            atten_mask=attn_mask,
            input_layout="BSND",
            pse=None,
            scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,

@@ -2716,7 +2692,7 @@ def _native_npu_attention(
        query,
        key,
        value,
        attn_mask,
        None,
        dropout_p,
        None,
        scale,

@@ -164,11 +164,7 @@ def compute_text_seq_len_from_mask(
    position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
    active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
    has_active = encoder_hidden_states_mask.any(dim=1)
    per_sample_len = torch.where(
        has_active,
        active_positions.max(dim=1).values + 1,
        torch.as_tensor(text_seq_len, device=encoder_hidden_states.device),
    )
    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
    return text_seq_len, per_sample_len, encoder_hidden_states_mask
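
Editor's illustration, not part of the diff: the per-sample length logic above on a toy mask.

```py
import torch

mask = torch.tensor([[True, True, False, False],
                     [True, True, True, True]])
position_ids = torch.arange(mask.shape[1])
# masked positions are zeroed out, so the highest surviving index + 1
# gives the active length of each sample
active_positions = torch.where(mask, position_ids, position_ids.new_zeros(()))
per_sample_len = torch.where(
    mask.any(dim=1), active_positions.max(dim=1).values + 1, torch.as_tensor(mask.shape[1])
)
print(per_sample_len)  # tensor([2, 4])
```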

@@ -40,7 +40,6 @@ from .modular_pipeline_utils import (
    InputParam,
    InsertableDict,
    OutputParam,
    _validate_requirements,
    combine_inputs,
    combine_outputs,
    format_components,

@@ -291,7 +290,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):

    config_name = "modular_config.json"
    model_name = None
    _requirements: dict[str, str] | None = None
    _workflow_map = None

    @classmethod

@@ -406,9 +404,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
                "Selected model repository does not appear to have any custom code or does not have a valid `config.json` file."
            )

        if "requirements" in config and config["requirements"] is not None:
            _ = _validate_requirements(config["requirements"])

        class_ref = config["auto_map"][cls.__name__]
        module_file, class_name = class_ref.split(".")
        module_file = module_file + ".py"

@@ -433,13 +428,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
        module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
        parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
        auto_map = {f"{parent_module}": f"{module}.{cls_name}"}

        self.register_to_config(auto_map=auto_map)

        # resolve requirements
        requirements = _validate_requirements(getattr(self, "_requirements", None))
        if requirements:
            self.register_to_config(requirements=requirements)

        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
        config = dict(self.config)
        self._internal_dict = FrozenDict(config)

@@ -1250,14 +1240,6 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
            expected_configs=self.expected_configs,
        )

    @property
    def _requirements(self) -> dict[str, str]:
        requirements = {}
        for block_name, block in self.sub_blocks.items():
            if getattr(block, "_requirements", None):
                requirements[block_name] = block._requirements
        return requirements


class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
    """
@@ -22,12 +22,10 @@ from typing import Any, Literal, Type, Union, get_args, get_origin

import PIL.Image
import torch
from packaging.specifiers import InvalidSpecifier, SpecifierSet

from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
from ..utils.import_utils import _is_package_available


if is_torch_available():

@@ -974,89 +972,6 @@ def make_doc_string(
    return output


def _validate_requirements(reqs):
    if reqs is None:
        normalized_reqs = {}
    else:
        if not isinstance(reqs, dict):
            raise ValueError(
                "Requirements must be provided as a dictionary mapping package names to version specifiers."
            )
        normalized_reqs = _normalize_requirements(reqs)

    if not normalized_reqs:
        return {}

    final: dict[str, str] = {}
    for req, specified_ver in normalized_reqs.items():
        req_available, req_actual_ver = _is_package_available(req)
        if not req_available:
            logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")

        if specified_ver:
            try:
                specifier = SpecifierSet(specified_ver)
            except InvalidSpecifier as err:
                raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err

            if req_actual_ver == "N/A":
                logger.warning(
                    f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpectedly."
                )
            elif not specifier.contains(req_actual_ver, prereleases=True):
                logger.warning(
                    f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpectedly."
                )

        final[req] = specified_ver

    return final


def _normalize_requirements(reqs):
    if not reqs:
        return {}

    normalized: "OrderedDict[str, str]" = OrderedDict()

    def _accumulate(mapping: dict[str, Any]):
        for pkg, spec in mapping.items():
            if isinstance(spec, dict):
                # This is recursive because blocks are composable. This way, we can merge requirements
                # from multiple blocks.
                _accumulate(spec)
                continue

            pkg_name = str(pkg).strip()
            if not pkg_name:
                raise ValueError("Requirement package name cannot be empty.")

            spec_str = "" if spec is None else str(spec).strip()
            if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
                spec_str = f"=={spec_str}"

            existing_spec = normalized.get(pkg_name)
            if existing_spec is not None:
                if not existing_spec and spec_str:
                    normalized[pkg_name] = spec_str
                elif existing_spec and spec_str and existing_spec != spec_str:
                    try:
                        combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
                    except InvalidSpecifier:
                        logger.warning(
                            f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
                        )
                    else:
                        normalized[pkg_name] = str(combined_spec)
                continue

            normalized[pkg_name] = spec_str

    _accumulate(reqs)

    return normalized
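
Editor's illustration, not part of the diff: roughly what the normalization rules above produce.

```py
_normalize_requirements({"transformers": "4.57.1"})
# -> {"transformers": "==4.57.1"}   (a bare version becomes an exact pin)

_normalize_requirements({"block_a": {"torch": ">=2.0"}, "torch": ">=2.1"})
# nested block dicts are flattened recursively; overlapping specifiers are
# combined into one SpecifierSet -> {"torch": ">=2.0,>=2.1"}
```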


def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
    """
    Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current
@@ -465,7 +465,8 @@ class UNetTesterMixin:
    def test_forward_with_norm_groups(self):
        if not self._accepts_norm_num_groups(self.model_class):
            pytest.skip(f"Test not supported for {self.model_class.__name__}")
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

@@ -480,9 +481,9 @@ class UNetTesterMixin:
        if isinstance(output, dict):
            output = output.to_tuple()[0]

        self.assertIsNotNone(output)
        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
        assert output.shape == expected_shape, "Input and output shapes do not match"


class ModelTesterMixin:
@@ -287,8 +287,9 @@ class ModelTesterMixin:
                f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
            )

        image = model(**self.get_dummy_inputs(), return_dict=False)[0]
        new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
        inputs_dict = self.get_dummy_inputs()
        image = model(**inputs_dict, return_dict=False)[0]
        new_image = new_model(**inputs_dict, return_dict=False)[0]

        assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")

@@ -308,8 +309,9 @@ class ModelTesterMixin:

        new_model.to(torch_device)

        image = model(**self.get_dummy_inputs(), return_dict=False)[0]
        new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
        inputs_dict = self.get_dummy_inputs()
        image = model(**inputs_dict, return_dict=False)[0]
        new_image = new_model(**inputs_dict, return_dict=False)[0]

        assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")

@@ -337,8 +339,9 @@ class ModelTesterMixin:
        model.to(torch_device)
        model.eval()

        first = model(**self.get_dummy_inputs(), return_dict=False)[0]
        second = model(**self.get_dummy_inputs(), return_dict=False)[0]
        inputs_dict = self.get_dummy_inputs()
        first = model(**inputs_dict, return_dict=False)[0]
        second = model(**inputs_dict, return_dict=False)[0]

        first_flat = first.flatten()
        second_flat = second.flatten()

@@ -395,8 +398,9 @@ class ModelTesterMixin:
        model.to(torch_device)
        model.eval()

        outputs_dict = model(**self.get_dummy_inputs())
        outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
        inputs_dict = self.get_dummy_inputs()
        outputs_dict = model(**inputs_dict)
        outputs_tuple = model(**inputs_dict, return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)

@@ -523,8 +527,10 @@ class ModelTesterMixin:
        new_model = new_model.to(torch_device)

        torch.manual_seed(0)
        inputs_dict_new = self.get_dummy_inputs()
        new_output = new_model(**inputs_dict_new, return_dict=False)[0]
        # Re-create inputs only if they contain a generator (which needs to be reset)
        if "generator" in inputs_dict:
            inputs_dict = self.get_dummy_inputs()
        new_output = new_model(**inputs_dict, return_dict=False)[0]

        assert_tensors_close(
            base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"

@@ -563,8 +569,10 @@ class ModelTesterMixin:
        new_model = new_model.to(torch_device)

        torch.manual_seed(0)
        inputs_dict_new = self.get_dummy_inputs()
        new_output = new_model(**inputs_dict_new, return_dict=False)[0]
        # Re-create inputs only if they contain a generator (which needs to be reset)
        if "generator" in inputs_dict:
            inputs_dict = self.get_dummy_inputs()
        new_output = new_model(**inputs_dict, return_dict=False)[0]

        assert_tensors_close(
            base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"

@@ -614,8 +622,10 @@ class ModelTesterMixin:
        model_parallel = model_parallel.to(torch_device)

        torch.manual_seed(0)
        inputs_dict_parallel = self.get_dummy_inputs()
        output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
        # Re-create inputs only if they contain a generator (which needs to be reset)
        if "generator" in inputs_dict:
            inputs_dict = self.get_dummy_inputs()
        output_parallel = model_parallel(**inputs_dict, return_dict=False)[0]

        assert_tensors_close(
            base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
@@ -92,9 +92,6 @@ class TorchCompileTesterMixin:
        model.eval()
        model.compile_repeated_blocks(fullgraph=True)

        if self.model_class.__name__ == "UNet2DConditionModel":
            recompile_limit = 2

        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(recompile_limit=recompile_limit),
@@ -15,6 +15,7 @@

import gc
import json
import logging
import os
import re

@@ -23,10 +24,12 @@ import safetensors.torch
import torch
import torch.nn as nn

from diffusers.utils import logging as diffusers_logging
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import check_if_dicts_are_equal

from ...testing_utils import (
    CaptureLogger,
    assert_tensors_close,
    backend_empty_cache,
    is_lora,

@@ -477,10 +480,7 @@ class LoraHotSwappingForModelTesterMixin:
        with pytest.raises(RuntimeError, match=msg):
            model.enable_lora_hotswap(target_rank=32)

    def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
        # ensure that enable_lora_hotswap is called before loading the first adapter
        import logging

    def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
        lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

@@ -488,21 +488,26 @@ class LoraHotSwappingForModelTesterMixin:
        msg = (
            "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
        )
        with caplog.at_level(logging.WARNING):

        logger = diffusers_logging.get_logger("diffusers.loaders.peft")
        logger.setLevel(logging.WARNING)
        with CaptureLogger(logger) as cap_logger:
            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
        assert any(msg in record.message for record in caplog.records)

    def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
        # check possibility to ignore the error/warning
        import logging
        assert msg in str(cap_logger.out), f"Expected warning not found. Captured: {cap_logger.out}"

    def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
        lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)
        model.add_adapter(lora_config)
        with caplog.at_level(logging.WARNING):

        logger = diffusers_logging.get_logger("diffusers.loaders.peft")
        logger.setLevel(logging.WARNING)
        with CaptureLogger(logger) as cap_logger:
            model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
        assert len(caplog.records) == 0

        assert cap_logger.out == "", f"Expected no warnings but found: {cap_logger.out}"

    def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
        # check that wrong argument value raises an error

@@ -515,9 +520,6 @@ class LoraHotSwappingForModelTesterMixin:
            model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")

    def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
        # check the error and log
        import logging

        # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
        target_modules0 = ["to_q"]
        target_modules1 = ["to_q", "to_k"]
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pytest
import torch

@@ -26,64 +24,39 @@ from ...testing_utils import (
    slow,
    torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
    BaseModelTesterConfig,
    MemoryTesterMixin,
    ModelTesterMixin,
)


class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet1DModel
    main_input_name = "sample"
_LAYERWISE_CASTING_XFAIL_REASON = (
    "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
    "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
    "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
    "2. Unskip this test."
)


class UNet1DTesterConfig(BaseModelTesterConfig):
    """Base configuration for UNet1DModel testing (standard variant)."""

    @property
    def dummy_input(self):
        batch_size = 4
        num_features = 14
        seq_len = 16

        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        time_step = torch.tensor([10] * batch_size).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (4, 14, 16)
    def model_class(self):
        return UNet1DModel

    @property
    def output_shape(self):
        return (4, 14, 16)
        return (14, 16)

    @unittest.skip("Test not supported.")
    def test_ema_training(self):
        pass
    @property
    def main_input_name(self):
        return "sample"

    @unittest.skip("Test not supported.")
    def test_training(self):
        pass

    @unittest.skip("Test not supported.")
    def test_layerwise_casting_training(self):
        pass

    def test_determinism(self):
        super().test_determinism()

    def test_outputs_equivalence(self):
        super().test_outputs_equivalence()

    def test_from_save_pretrained(self):
        super().test_from_save_pretrained()

    def test_from_save_pretrained_variant(self):
        super().test_from_save_pretrained_variant()

    def test_model_from_pretrained(self):
        super().test_model_from_pretrained()

    def test_output(self):
        super().test_output()

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
    def get_init_dict(self):
        return {
            "block_out_channels": (8, 8, 16, 16),
            "in_channels": 14,
            "out_channels": 14,

@@ -97,18 +70,40 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
            "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
            "act_fn": "swish",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def get_dummy_inputs(self):
        batch_size = 4
        num_features = 14
        seq_len = 16

        return {
            "sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
            "timestep": torch.tensor([10] * batch_size).to(torch_device),
        }


class TestUNet1D(UNet1DTesterConfig, ModelTesterMixin, UNetTesterMixin):
    @pytest.mark.skip("Not implemented yet for this UNet")
    def test_forward_with_norm_groups(self):
        pass


class TestUNet1DMemory(UNet1DTesterConfig, MemoryTesterMixin):
    @pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
    def test_layerwise_casting_memory(self):
        super().test_layerwise_casting_memory()


class TestUNet1DHubLoading(UNet1DTesterConfig):
    def test_from_pretrained_hub(self):
        model, loading_info = UNet1DModel.from_pretrained(
            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)
        assert model is not None
        assert len(loading_info["missing_keys"]) == 0

        model.to(torch_device)
        image = model(**self.dummy_input)
        image = model(**self.get_dummy_inputs())

        assert image is not None, "Make sure output is not None"

@@ -131,12 +126,7 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        # fmt: off
        expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))

    @unittest.skip("Test not supported.")
    def test_forward_with_norm_groups(self):
        # Not implemented yet for this UNet
        pass
        assert torch.allclose(output_slice, expected_output_slice, rtol=1e-3)

    @slow
    def test_unet_1d_maestro(self):

@@ -157,98 +147,29 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        assert (output_sum - 224.0896).abs() < 0.5
        assert (output_max - 0.0607).abs() < 4e-4

    @pytest.mark.xfail(
        reason=(
            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
            "2. Unskip this test."
        ),
    )
    def test_layerwise_casting_inference(self):
        super().test_layerwise_casting_inference()

    @pytest.mark.xfail(
        reason=(
            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
            "2. Unskip this test."
        ),
    )
    def test_layerwise_casting_memory(self):
        pass
# =============================================================================
# UNet1D RL (Value Function) Model Tests
# =============================================================================


class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet1DModel
    main_input_name = "sample"
class UNet1DRLTesterConfig(BaseModelTesterConfig):
    """Base configuration for UNet1DModel testing (RL value function variant)."""

    @property
    def dummy_input(self):
        batch_size = 4
        num_features = 14
        seq_len = 16

        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        time_step = torch.tensor([10] * batch_size).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (4, 14, 16)
    def model_class(self):
        return UNet1DModel

    @property
    def output_shape(self):
        return (4, 14, 1)
        return (1,)

    def test_determinism(self):
        super().test_determinism()
    @property
    def main_input_name(self):
        return "sample"

    def test_outputs_equivalence(self):
        super().test_outputs_equivalence()

    def test_from_save_pretrained(self):
        super().test_from_save_pretrained()

    def test_from_save_pretrained_variant(self):
        super().test_from_save_pretrained_variant()

    def test_model_from_pretrained(self):
        super().test_model_from_pretrained()

    def test_output(self):
        # UNetRL is a value-function is different output shape
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    @unittest.skip("Test not supported.")
    def test_ema_training(self):
        pass

    @unittest.skip("Test not supported.")
    def test_training(self):
        pass

    @unittest.skip("Test not supported.")
    def test_layerwise_casting_training(self):
        pass

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
    def get_init_dict(self):
        return {
            "in_channels": 14,
            "out_channels": 14,
            "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],

@@ -264,18 +185,54 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
            "time_embedding_type": "positional",
            "act_fn": "mish",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def get_dummy_inputs(self):
        batch_size = 4
        num_features = 14
        seq_len = 16

        return {
            "sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
            "timestep": torch.tensor([10] * batch_size).to(torch_device),
        }


class TestUNet1DRL(UNet1DRLTesterConfig, ModelTesterMixin, UNetTesterMixin):
    @pytest.mark.skip("Not implemented yet for this UNet")
    def test_forward_with_norm_groups(self):
        pass

    @torch.no_grad()
    def test_output(self):
        # UNetRL is a value-function with different output shape (batch, 1)
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        inputs_dict = self.get_dummy_inputs()
        output = model(**inputs_dict, return_dict=False)[0]

        assert output is not None
        expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
        assert output.shape == expected_shape, "Input and output shapes do not match"


class TestUNet1DRLMemory(UNet1DRLTesterConfig, MemoryTesterMixin):
    @pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
    def test_layerwise_casting_memory(self):
        super().test_layerwise_casting_memory()


class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
    def test_from_pretrained_hub(self):
        value_function, vf_loading_info = UNet1DModel.from_pretrained(
            "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
        )
        self.assertIsNotNone(value_function)
        self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
        assert value_function is not None
        assert len(vf_loading_info["missing_keys"]) == 0

        value_function.to(torch_device)
        image = value_function(**self.dummy_input)
        image = value_function(**self.get_dummy_inputs())

        assert image is not None, "Make sure output is not None"

@@ -299,31 +256,4 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        # fmt: off
        expected_output_slice = torch.tensor([165.25] * seq_len)
        # fmt: on
        self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))

    @unittest.skip("Test not supported.")
    def test_forward_with_norm_groups(self):
        # Not implemented yet for this UNet
        pass

    @pytest.mark.xfail(
        reason=(
            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
            "2. Unskip this test."
        ),
    )
    def test_layerwise_casting_inference(self):
        pass

    @pytest.mark.xfail(
        reason=(
            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
            "2. Unskip this test."
        ),
    )
    def test_layerwise_casting_memory(self):
        pass
        assert torch.allclose(output, expected_output_slice, rtol=1e-3)
@@ -15,12 +15,11 @@

import gc
import math
import unittest

import pytest
import torch

from diffusers import UNet2DModel
from diffusers.utils import logging

from ...testing_utils import (
    backend_empty_cache,

@@ -31,39 +30,40 @@ from ...testing_utils import (
    torch_all_close,
    torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
    BaseModelTesterConfig,
    MemoryTesterMixin,
    ModelTesterMixin,
    TrainingTesterMixin,
)


logger = logging.get_logger(__name__)

enable_full_determinism()


class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DModel
    main_input_name = "sample"
# =============================================================================
# Standard UNet2D Model Tests
# =============================================================================


class UNet2DTesterConfig(BaseModelTesterConfig):
    """Base configuration for standard UNet2DModel testing."""

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (3, 32, 32)
    def model_class(self):
        return UNet2DModel

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
    @property
    def main_input_name(self):
        return "sample"

    def get_init_dict(self):
        return {
            "block_out_channels": (4, 8),
            "norm_num_groups": 2,
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),

@@ -74,11 +74,22 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
            "layers_per_block": 2,
            "sample_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def get_dummy_inputs(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        return {
            "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
            "timestep": torch.tensor([10]).to(torch_device),
        }


class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
    def test_mid_block_attn_groups(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["add_attention"] = True
        init_dict["attn_norm_num_groups"] = 4

@@ -87,13 +98,11 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        model.to(torch_device)
        model.eval()

        self.assertIsNotNone(
            model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
        assert model.mid_block.attentions[0].group_norm is not None, (
            "Mid block Attention group norm should exist but does not."
        )
        self.assertEqual(
            model.mid_block.attentions[0].group_norm.num_groups,
            init_dict["attn_norm_num_groups"],
            "Mid block Attention group norm does not have the expected number of groups.",
        assert model.mid_block.attentions[0].group_norm.num_groups == init_dict["attn_norm_num_groups"], (
            "Mid block Attention group norm does not have the expected number of groups."
        )

        with torch.no_grad():

@@ -102,13 +111,15 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        if isinstance(output, dict):
            output = output.to_tuple()[0]

        self.assertIsNotNone(output)
        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
        assert output.shape == expected_shape, "Input and output shapes do not match"

    def test_mid_block_none(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        mid_none_init_dict = self.get_init_dict()
        mid_none_inputs_dict = self.get_dummy_inputs()
        mid_none_init_dict["mid_block_type"] = None

        model = self.model_class(**init_dict)

@@ -119,7 +130,7 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        mid_none_model.to(torch_device)
        mid_none_model.eval()

        self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
        assert mid_none_model.mid_block is None, "Mid block should not exist."

        with torch.no_grad():
            output = model(**inputs_dict)

@@ -133,8 +144,10 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        if isinstance(mid_none_output, dict):
            mid_none_output = mid_none_output.to_tuple()[0]

        self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
        assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different."


class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "AttnUpBlock2D",

@@ -143,41 +156,32 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
            "UpBlock2D",
            "DownBlock2D",
        }

        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
        attention_head_dim = 8
        block_out_channels = (16, 32)

        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
        )
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DModel
    main_input_name = "sample"
# =============================================================================
# UNet2D LDM Model Tests
# =============================================================================


class UNet2DLDMTesterConfig(BaseModelTesterConfig):
    """Base configuration for UNet2DModel LDM variant testing."""

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (4, 32, 32)
    def model_class(self):
        return UNet2DModel

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
    @property
    def main_input_name(self):
        return "sample"

    def get_init_dict(self):
        return {
            "sample_size": 32,
            "in_channels": 4,
            "out_channels": 4,

@@ -187,17 +191,34 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
            "down_block_types": ("DownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "UpBlock2D"),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def get_dummy_inputs(self):
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        return {
            "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
            "timestep": torch.tensor([10]).to(torch_device),
        }


class TestUNet2DLDMTraining(UNet2DLDMTesterConfig, TrainingTesterMixin):
    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
    def test_from_pretrained_hub(self):
        model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)

        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)
        assert model is not None
        assert len(loading_info["missing_keys"]) == 0

        model.to(torch_device)
        image = model(**self.dummy_input).sample
        image = model(**self.get_dummy_inputs()).sample

        assert image is not None, "Make sure output is not None"

@@ -205,7 +226,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    def test_from_pretrained_accelerate(self):
        model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model.to(torch_device)
        image = model(**self.dummy_input).sample
        image = model(**self.get_dummy_inputs()).sample

        assert image is not None, "Make sure output is not None"

@@ -265,44 +286,31 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}

        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
        attention_head_dim = 32
        block_out_channels = (32, 64)

        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
        )
        assert torch_all_close(output_slice, expected_output_slice, rtol=1e-3)


class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DModel
    main_input_name = "sample"
# =============================================================================
# NCSN++ Model Tests
# =============================================================================


class NCSNppTesterConfig(BaseModelTesterConfig):
    """Base configuration for UNet2DModel NCSN++ variant testing."""

    @property
    def dummy_input(self, sizes=(32, 32)):
        batch_size = 4
        num_channels = 3

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (3, 32, 32)
    def model_class(self):
        return UNet2DModel

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
    @property
    def main_input_name(self):
        return "sample"

    def get_init_dict(self):
        return {
            "block_out_channels": [32, 64, 64, 64],
            "in_channels": 3,
            "layers_per_block": 1,

@@ -324,17 +332,71 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
                "SkipUpBlock2D",
            ],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def get_dummy_inputs(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        return {
            "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
            "timestep": torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device),
        }


class TestNCSNpp(NCSNppTesterConfig, ModelTesterMixin, UNetTesterMixin):
    @pytest.mark.skip("Test not supported.")
    def test_forward_with_norm_groups(self):
        pass

    @pytest.mark.skip(
        "To make layerwise casting work with this model, we will have to update the implementation. "
        "Due to potentially low usage, we don't support it here."
    )
    def test_keep_in_fp32_modules(self):
        pass

    @pytest.mark.skip(
        "To make layerwise casting work with this model, we will have to update the implementation. "
        "Due to potentially low usage, we don't support it here."
    )
    def test_from_save_pretrained_dtype_inference(self):
        pass


class TestNCSNppMemory(NCSNppTesterConfig, MemoryTesterMixin):
    @pytest.mark.skip(
        "To make layerwise casting work with this model, we will have to update the implementation. "
        "Due to potentially low usage, we don't support it here."
    )
    def test_layerwise_casting_memory(self):
        pass

    @pytest.mark.skip(
        "To make layerwise casting work with this model, we will have to update the implementation. "
        "Due to potentially low usage, we don't support it here."
    )
    def test_layerwise_casting_training(self):
        pass


class TestNCSNppTraining(NCSNppTesterConfig, TrainingTesterMixin):
    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "UNetMidBlock2D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class TestNCSNppHubLoading(NCSNppTesterConfig):
    @slow
    def test_from_pretrained_hub(self):
        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)
        assert model is not None
        assert len(loading_info["missing_keys"]) == 0

        model.to(torch_device)
        inputs = self.dummy_input
        inputs = self.get_dummy_inputs()
        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
        inputs["sample"] = noise
        image = model(**inputs)

@@ -361,7 +423,7 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
        assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)

    def test_output_pretrained_ve_large(self):
        model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")

@@ -382,35 +444,4 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

    @unittest.skip("Test not supported.")
    def test_forward_with_norm_groups(self):
        # not required for this model
        pass

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "UNetMidBlock2D",
        }

        block_out_channels = (32, 64, 64, 64)

        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, block_out_channels=block_out_channels
        )

    def test_effective_gradient_checkpointing(self):
        super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})

    @unittest.skip(
        "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
    )
    def test_layerwise_casting_inference(self):
        pass

    @unittest.skip(
        "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
    )
    def test_layerwise_casting_memory(self):
        pass
        assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
@@ -20,6 +20,7 @@ import tempfile
import unittest
from collections import OrderedDict

import pytest
import torch
from huggingface_hub import snapshot_download
from parameterized import parameterized
@@ -52,17 +53,24 @@ from ...testing_utils import (
    torch_all_close,
    torch_device,
)
from ..test_modeling_common import (
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
    AttentionTesterMixin,
    BaseModelTesterConfig,
    IPAdapterTesterMixin,
    LoraHotSwappingForModelTesterMixin,
    LoraTesterMixin,
    MemoryTesterMixin,
    ModelTesterMixin,
    TorchCompileTesterMixin,
    UNetTesterMixin,
    TrainingTesterMixin,
)


if is_peft_available():
    from peft import LoraConfig
    from peft.tuners.tuners_utils import BaseTunerLayer

    from ..testing_utils.lora import check_if_lora_correctly_set


logger = logging.get_logger(__name__)
@@ -82,16 +90,6 @@ def get_unet_lora_config():
    return unet_lora_config


def check_if_lora_correctly_set(model) -> bool:
    """
    Checks if the LoRA layers are correctly set with peft
    """
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            return True
    return False


def create_ip_adapter_state_dict(model):
    # "ip_adapter" (cross-attention weights)
    ip_cross_attn_state_dict = {}
@@ -354,34 +352,28 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
    return custom_diffusion_attn_procs


class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DConditionModel
    main_input_name = "sample"
    # We override the items here because the unet under consideration is small.
    model_split_percents = [0.5, 0.34, 0.4]
class UNet2DConditionTesterConfig(BaseModelTesterConfig):
    """Base configuration for UNet2DConditionModel testing."""

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 4
        sizes = (16, 16)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)

        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
    def model_class(self):
        return UNet2DConditionModel

    @property
    def input_shape(self):
    def output_shape(self) -> tuple[int, int, int]:
        return (4, 16, 16)

    @property
    def output_shape(self):
        return (4, 16, 16)
    def model_split_percents(self) -> list[float]:
        return [0.5, 0.34, 0.4]

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
    @property
    def main_input_name(self) -> str:
        return "sample"

    def get_init_dict(self) -> dict:
        """Return UNet2D model initialization arguments."""
        return {
            "block_out_channels": (4, 8),
            "norm_num_groups": 4,
            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
@@ -393,26 +385,24 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
            "layers_per_block": 1,
            "sample_size": 16,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        """Return dummy inputs for UNet2D model."""
        batch_size = 4
        num_channels = 4
        sizes = (16, 16)

        model.enable_xformers_memory_efficient_attention()
        return {
            "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
            "timestep": torch.tensor([10]).to(torch_device),
            "encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
        }

        assert (
            model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
            == "XFormersAttnProcessor"
        ), "xformers is not enabled"

class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
    def test_model_with_attention_head_dim_tuple(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)
@@ -427,12 +417,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
        assert output.shape == expected_shape, "Input and output shapes do not match"

    def test_model_with_use_linear_projection(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["use_linear_projection"] = True

@@ -446,12 +437,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
        assert output.shape == expected_shape, "Input and output shapes do not match"

    def test_model_with_cross_attention_dim_tuple(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["cross_attention_dim"] = (8, 8)

@@ -465,12 +457,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
        assert output.shape == expected_shape, "Input and output shapes do not match"

    def test_model_with_simple_projection(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        batch_size, _, _, sample_size = inputs_dict["sample"].shape

@@ -489,12 +482,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
        assert output.shape == expected_shape, "Input and output shapes do not match"

    def test_model_with_class_embeddings_concat(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        batch_size, _, _, sample_size = inputs_dict["sample"].shape

@@ -514,12 +508,287 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
        assert output.shape == expected_shape, "Input and output shapes do not match"

    # see diffusers.models.attention_processor::Attention#prepare_attention_mask
    # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
    # since the use-case (somebody passes in a too-short cross-attn mask) is pretty small,
    # maybe it's fine that this only works for the unclip use-case.
    @mark.skip(
        reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
    )
    def test_model_xattn_padding(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
        model.to(torch_device)
        model.eval()

        cond = inputs_dict["encoder_hidden_states"]
        with torch.no_grad():
            full_cond_out = model(**inputs_dict).sample
            assert full_cond_out is not None

            batch, tokens, _ = cond.shape
            keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
            keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
            assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"

            trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
            trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
            assert trunc_mask_out.allclose(keeplast_out), (
                "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
            )

    def test_pickle(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)
        model.to(torch_device)

        with torch.no_grad():
            sample = model(**inputs_dict).sample

        sample_copy = copy.copy(sample)

        assert (sample - sample_copy).abs().max() < 1e-4

    def test_asymmetrical_unet(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        # Add asymmetry to configs
        init_dict["transformer_layers_per_block"] = [[3, 2], 1]
        init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)

        output = model(**inputs_dict).sample
        expected_shape = inputs_dict["sample"].shape

        # Check if input and output shapes are the same
        assert output.shape == expected_shape, "Input and output shapes do not match"


class TestUNet2DConditionHubLoading(UNet2DConditionTesterConfig):
    """Hub checkpoint loading tests for UNet2DConditionModel."""

    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
        ]
    )
    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
        inputs_dict = self.get_dummy_inputs()
        loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
        loaded_model = loaded_model.to(torch_device)
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
        ]
    )
    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
        inputs_dict = self.get_dummy_inputs()
        loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
        loaded_model = loaded_model.to(torch_device)
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub_local(self):
        inputs_dict = self.get_dummy_inputs()
        ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
        loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
        loaded_model = loaded_model.to(torch_device)
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
        inputs_dict = self.get_dummy_inputs()
        ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
        loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
        loaded_model = loaded_model.to(torch_device)
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
        ]
    )
    def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
        inputs_dict = self.get_dummy_inputs()
        loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
        ]
    )
    def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
        inputs_dict = self.get_dummy_inputs()
        loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_device_map_from_hub_local(self):
        inputs_dict = self.get_dummy_inputs()
        ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
        loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
        inputs_dict = self.get_dummy_inputs()
        ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
        loaded_model = self.model_class.from_pretrained(
            ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
        )
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

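The Hub-loading tests above all reduce to the same `from_pretrained` call; the repo ids are `hf-internal-testing` fixtures, and any sharded checkpoint works the same way. A hedged usage sketch:

```py
# Loading a sharded checkpoint from the Hub, as exercised by the tests above.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/unet2d-sharded-dummy",  # test fixture; substitute your repo
    device_map="auto",  # optional: let accelerate place the shards across devices
)
```
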
class TestUNet2DConditionLoRA(UNet2DConditionTesterConfig, LoraTesterMixin):
    """LoRA adapter tests for UNet2DConditionModel."""

    @require_peft_backend
    def test_load_attn_procs_raise_warning(self):
        """Test that deprecated load_attn_procs method raises FutureWarning."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        # forward pass without LoRA
        with torch.no_grad():
            non_lora_sample = model(**inputs_dict).sample

        unet_lora_config = get_unet_lora_config()
        model.add_adapter(unet_lora_config)

        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."

        # forward pass with LoRA
        with torch.no_grad():
            lora_sample_1 = model(**inputs_dict).sample

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_attn_procs(tmpdirname)
            model.unload_lora()

            with pytest.warns(FutureWarning, match="Using the `load_attn_procs\\(\\)` method has been deprecated"):
                model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

            # important to still check for the rest of the stuff.
            assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."

            with torch.no_grad():
                lora_sample_2 = model(**inputs_dict).sample

        assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
            "LoRA injected UNet should produce different results."
        )
        assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
            "Loading from a saved checkpoint should produce identical results."
        )

    @require_peft_backend
    def test_save_attn_procs_raise_warning(self):
        """Test that deprecated save_attn_procs method raises FutureWarning."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        unet_lora_config = get_unet_lora_config()
        model.add_adapter(unet_lora_config)

        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."

        with tempfile.TemporaryDirectory() as tmpdirname:
            with pytest.warns(FutureWarning, match="Using the `save_attn_procs\\(\\)` method has been deprecated"):
                model.save_attn_procs(os.path.join(tmpdirname))

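Note the assertion-style change running through this file: `self.assertWarns(FutureWarning)` plus a manual message check becomes a single `pytest.warns(FutureWarning, match=...)`. Since `match` is interpreted as a regular expression, the parentheses in the method names are escaped. A short sketch, with `model` and `path` standing in for the objects set up in the tests:

```py
import pytest

# `match` takes a regex, hence the escaped parentheses in the tests above.
with pytest.warns(FutureWarning, match=r"Using the `load_attn_procs\(\)` method has been deprecated"):
    model.load_attn_procs(path)  # placeholder objects from the surrounding test
```
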
class TestUNet2DConditionMemory(UNet2DConditionTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for UNet2DConditionModel."""


class TestUNet2DConditionTraining(UNet2DConditionTesterConfig, TrainingTesterMixin):
    """Training tests for UNet2DConditionModel."""

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "CrossAttnUpBlock2D",
            "CrossAttnDownBlock2D",
            "UNetMidBlock2DCrossAttn",
            "UpBlock2D",
            "Transformer2DModel",
            "DownBlock2D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin):
    """Attention processor tests for UNet2DConditionModel."""

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)

        model.enable_xformers_memory_efficient_attention()

        assert (
            model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
            == "XFormersAttnProcessor"
        ), "xformers is not enabled"

    def test_model_attention_slicing(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)
@@ -544,7 +813,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        assert output is not None

    def test_model_sliceable_head_dim(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)
@@ -562,21 +831,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        for module in model.children():
            check_sliceable_dim_attr(module)

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "CrossAttnUpBlock2D",
            "CrossAttnDownBlock2D",
            "UNetMidBlock2DCrossAttn",
            "UpBlock2D",
            "Transformer2DModel",
            "DownBlock2D",
        }
        attention_head_dim = (8, 16)
        block_out_channels = (16, 32)
        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
        )

    def test_special_attn_proc(self):
        class AttnEasyProc(torch.nn.Module):
            def __init__(self, num):
@@ -618,7 +872,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
                return hidden_states

        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)
@@ -645,7 +900,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        ]
    )
    def test_model_xattn_mask(self, mask_dtype):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
        model.to(torch_device)
@@ -675,39 +931,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
            "masking the last token from our cond should be equivalent to truncating that token out of the condition"
        )

    # see diffusers.models.attention_processor::Attention#prepare_attention_mask
    # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
    # since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric.
    # maybe it's fine that this only works for the unclip use-case.
    @mark.skip(
        reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
    )
    def test_model_xattn_padding(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
        model.to(torch_device)
        model.eval()

        cond = inputs_dict["encoder_hidden_states"]
        with torch.no_grad():
            full_cond_out = model(**inputs_dict).sample
            assert full_cond_out is not None

            batch, tokens, _ = cond.shape
            keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
            keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
            assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"

            trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
            trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
            assert trunc_mask_out.allclose(keeplast_out), (
                "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
            )
class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
    """Custom Diffusion processor tests for UNet2DConditionModel."""

    def test_custom_diffusion_processors(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)
@@ -733,8 +963,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        assert (sample1 - sample2).abs().max() < 3e-3

    def test_custom_diffusion_save_load(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)
@@ -754,7 +984,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_attn_procs(tmpdirname, safe_serialization=False)
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
            assert os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))
            torch.manual_seed(0)
            new_model = self.model_class(**init_dict)
            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
@@ -773,8 +1003,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_custom_diffusion_xformers_on_off(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)
@@ -798,41 +1028,28 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        assert (sample - on_sample).abs().max() < 1e-4
        assert (sample - off_sample).abs().max() < 1e-4

    def test_pickle(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)
class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterMixin):
    """IP Adapter tests for UNet2DConditionModel."""

        model = self.model_class(**init_dict)
        model.to(torch_device)
    @property
    def ip_adapter_processor_cls(self):
        return (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)

        with torch.no_grad():
            sample = model(**inputs_dict).sample
    def create_ip_adapter_state_dict(self, model):
        return create_ip_adapter_state_dict(model)

        sample_copy = copy.copy(sample)

        assert (sample - sample_copy).abs().max() < 1e-4

    def test_asymmetrical_unet(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        # Add asymmetry to configs
        init_dict["transformer_layers_per_block"] = [[3, 2], 1]
        init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)

        output = model(**inputs_dict).sample
        expected_shape = inputs_dict["sample"].shape

        # Check if input and output shapes are the same
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
    def modify_inputs_for_ip_adapter(self, model, inputs_dict):
        batch_size = inputs_dict["encoder_hidden_states"].shape[0]
        # for ip-adapter image_embeds has shape [batch_size, num_image, embed_dim]
        cross_attention_dim = getattr(model.config, "cross_attention_dim", 8)
        image_embeds = floats_tensor((batch_size, 1, cross_attention_dim)).to(torch_device)
        inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
        return inputs_dict

    def test_ip_adapter(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)
@@ -905,7 +1122,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)

    def test_ip_adapter_plus(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = (8, 16)
@@ -977,185 +1195,16 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
        assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)

    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
        ]
    )
    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
        loaded_model = loaded_model.to(torch_device)
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)
class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin):
    """Torch compile tests for UNet2DConditionModel."""

    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
        ]
    )
    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
        loaded_model = loaded_model.to(torch_device)
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub_local(self):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
        loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
        loaded_model = loaded_model.to(torch_device)
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
        loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
        loaded_model = loaded_model.to(torch_device)
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
        ]
    )
    def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
        ]
    )
    def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_device_map_from_hub_local(self):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
        loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
        loaded_model = self.model_class.from_pretrained(
            ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
        )
        new_output = loaded_model(**inputs_dict)

        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @require_peft_backend
    def test_load_attn_procs_raise_warning(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        # forward pass without LoRA
        with torch.no_grad():
            non_lora_sample = model(**inputs_dict).sample

        unet_lora_config = get_unet_lora_config()
        model.add_adapter(unet_lora_config)

        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."

        # forward pass with LoRA
        with torch.no_grad():
            lora_sample_1 = model(**inputs_dict).sample

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_attn_procs(tmpdirname)
            model.unload_lora()

            with self.assertWarns(FutureWarning) as warning:
                model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

            warning_message = str(warning.warnings[0].message)
            assert "Using the `load_attn_procs()` method has been deprecated" in warning_message

            # important to still check for the rest of the stuff.
            assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."

            with torch.no_grad():
                lora_sample_2 = model(**inputs_dict).sample

        assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
            "LoRA injected UNet should produce different results."
        )
        assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
            "Loading from a saved checkpoint should produce identical results."
        )

    @require_peft_backend
    def test_save_attn_procs_raise_warning(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        unet_lora_config = get_unet_lora_config()
        model.add_adapter(unet_lora_config)

        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."

        with tempfile.TemporaryDirectory() as tmpdirname:
            with self.assertWarns(FutureWarning) as warning:
                model.save_attn_procs(tmpdirname)

            warning_message = str(warning.warnings[0].message)
            assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
    def test_torch_compile_repeated_blocks(self):
        return super().test_torch_compile_repeated_blocks(recompile_limit=2)


class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = UNet2DConditionModel

    def prepare_init_args_and_inputs_for_common(self):
        return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()


class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    model_class = UNet2DConditionModel

    def prepare_init_args_and_inputs_for_common(self):
        return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin):
    """LoRA hot-swapping tests for UNet2DConditionModel."""


@slow

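For reference, the IP-Adapter tests above feed image embeddings through `added_cond_kwargs`, one `[batch_size, num_images, embed_dim]` tensor per adapter. A hedged sketch of the input convention, with `model` and `inputs_dict` standing in for a tester-config model and its dummy inputs:

```py
import torch

# image_embeds is a *list* of tensors, one entry per loaded IP-Adapter.
image_embeds = torch.randn(4, 1, 8)  # (batch_size, num_images, cross_attention_dim)
sample = model(**inputs_dict, added_cond_kwargs={"image_embeds": [image_embeds]}).sample
```
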
@@ -18,47 +18,44 @@ import unittest
import numpy as np
import torch

from diffusers.models import ModelMixin, UNet3DConditionModel
from diffusers.utils import logging
from diffusers import UNet3DConditionModel
from diffusers.utils.import_utils import is_xformers_available

from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ...testing_utils import (
    enable_full_determinism,
    floats_tensor,
    skip_mps,
    torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
    AttentionTesterMixin,
    BaseModelTesterConfig,
    ModelTesterMixin,
)


enable_full_determinism()

logger = logging.get_logger(__name__)


@skip_mps
class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet3DConditionModel
    main_input_name = "sample"
class UNet3DConditionTesterConfig(BaseModelTesterConfig):
    """Base configuration for UNet3DConditionModel testing."""

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 4
        num_frames = 4
        sizes = (16, 16)

        noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)

        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

    @property
    def input_shape(self):
        return (4, 4, 16, 16)
    def model_class(self):
        return UNet3DConditionModel

    @property
    def output_shape(self):
        return (4, 4, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
    @property
    def main_input_name(self):
        return "sample"

    def get_init_dict(self):
        return {
            "block_out_channels": (4, 8),
            "norm_num_groups": 4,
            "down_block_types": (
@@ -73,27 +70,25 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
            "layers_per_block": 1,
            "sample_size": 16,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
    def get_dummy_inputs(self):
        batch_size = 4
        num_channels = 4
        num_frames = 4
        sizes = (16, 16)

        model.enable_xformers_memory_efficient_attention()
        return {
            "sample": floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device),
            "timestep": torch.tensor([10]).to(torch_device),
            "encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
        }

        assert (
            model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
            == "XFormersAttnProcessor"
        ), "xformers is not enabled"

class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
    # Overriding because `norm_num_groups` needs to be different for this model.
    def test_forward_with_norm_groups(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        init_dict["block_out_channels"] = (32, 64)
        init_dict["norm_num_groups"] = 32

@@ -107,39 +102,74 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
        assert output.shape == expected_shape, "Input and output shapes do not match"

    # Overriding since the UNet3D outputs a different structure.
    @torch.no_grad()
    def test_determinism(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(model, ModelMixin):
                model(**self.dummy_input)
        inputs_dict = self.get_dummy_inputs()

            first = model(**inputs_dict)
            if isinstance(first, dict):
                first = first.sample
        first = model(**inputs_dict)
        if isinstance(first, dict):
            first = first.sample

            second = model(**inputs_dict)
            if isinstance(second, dict):
                second = second.sample
        second = model(**inputs_dict)
        if isinstance(second, dict):
            second = second.sample

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, 1e-5)
        assert max_diff <= 1e-5

    def test_feed_forward_chunking(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        init_dict["block_out_channels"] = (32, 64)
        init_dict["norm_num_groups"] = 32

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)[0]

        model.enable_forward_chunking()
        with torch.no_grad():
            output_2 = model(**inputs_dict)[0]

        assert output.shape == output_2.shape, "Shape doesn't match"
        assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2


class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterMixin):
    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)

        model.enable_xformers_memory_efficient_attention()

        assert (
            model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
            == "XFormersAttnProcessor"
        ), "xformers is not enabled"

    def test_model_attention_slicing(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["block_out_channels"] = (16, 32)
        init_dict["attention_head_dim"] = 8
@@ -162,22 +192,3 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        with torch.no_grad():
            output = model(**inputs_dict)
            assert output is not None

    def test_feed_forward_chunking(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict["block_out_channels"] = (32, 64)
        init_dict["norm_num_groups"] = 32

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)[0]

        model.enable_forward_chunking()
        with torch.no_grad():
            output_2 = model(**inputs_dict)[0]

        self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
        assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2

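The feed-forward chunking test above checks that `enable_forward_chunking()` (a real diffusers API, shown in the hunk) trades compute for memory without changing the output beyond tolerance. A minimal sketch, assuming `model` and `inputs_dict` come from one of the tester configs:

```py
import torch

model.eval()
with torch.no_grad():
    baseline = model(**inputs_dict)[0]

model.enable_forward_chunking()  # chunk the transformer feed-forward to save memory
with torch.no_grad():
    chunked = model(**inputs_dict)[0]

# Chunking must be numerically equivalent to the plain forward pass.
assert baseline.shape == chunked.shape
assert (baseline - chunked).abs().max() < 1e-2
```
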
@@ -13,59 +13,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import pytest
import torch
from torch import nn

from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
from diffusers.utils import logging

from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
    BaseModelTesterConfig,
    ModelTesterMixin,
    TrainingTesterMixin,
)


logger = logging.get_logger(__name__)

enable_full_determinism()


class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNetControlNetXSModel
    main_input_name = "sample"
class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
    """Base configuration for UNetControlNetXSModel testing."""

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 4
        sizes = (16, 16)
        conditioning_image_size = (3, 32, 32)  # size of additional, unprocessed image for control-conditioning

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
        controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
        conditioning_scale = 1

        return {
            "sample": noise,
            "timestep": time_step,
            "encoder_hidden_states": encoder_hidden_states,
            "controlnet_cond": controlnet_cond,
            "conditioning_scale": conditioning_scale,
        }

    @property
    def input_shape(self):
        return (4, 16, 16)
    def model_class(self):
        return UNetControlNetXSModel

    @property
    def output_shape(self):
        return (4, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
    @property
    def main_input_name(self):
        return "sample"

    def get_init_dict(self):
        return {
            "sample_size": 16,
            "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
            "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
@@ -80,11 +63,23 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
            "ctrl_max_norm_num_groups": 2,
            "ctrl_conditioning_embedding_out_channels": (2, 2),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def get_dummy_inputs(self):
        batch_size = 4
        num_channels = 4
        sizes = (16, 16)
        conditioning_image_size = (3, 32, 32)

        return {
            "sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
            "timestep": torch.tensor([10]).to(torch_device),
            "encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
            "controlnet_cond": floats_tensor((batch_size, *conditioning_image_size)).to(torch_device),
            "conditioning_scale": 1,
        }

    def get_dummy_unet(self):
        """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
        """Build the underlying UNet for tests that construct UNetControlNetXSModel from UNet + Adapter."""
        return UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
@@ -99,10 +94,16 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
        )

    def get_dummy_controlnet_from_unet(self, unet, **kwargs):
        """For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
        # size_ratio and conditioning_embedding_out_channels chosen to keep model small
        """Build the ControlNetXS-Adapter from a UNet."""
        return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)


class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetTesterMixin):
    @pytest.mark.skip("Test not supported.")
    def test_forward_with_norm_groups(self):
        # UNetControlNetXSModel only supports SD/SDXL with norm_num_groups=32
        pass

    def test_from_unet(self):
        unet = self.get_dummy_unet()
        controlnet = self.get_dummy_controlnet_from_unet(unet)
@@ -115,7 +116,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
            assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)

        # # check unet
        # everything expect down,mid,up blocks
        # everything except down,mid,up blocks
        modules_from_unet = [
            "time_embedding",
            "conv_in",
@@ -152,7 +153,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
            assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")

        # # check controlnet
        # everything expect down,mid,up blocks
        # everything except down,mid,up blocks
        modules_from_controlnet = {
            "controlnet_cond_embedding": "controlnet_cond_embedding",
            "conv_in": "ctrl_conv_in",
@@ -193,12 +194,12 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
            for p in module.parameters():
                assert p.requires_grad

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        model = UNetControlNetXSModel(**init_dict)
        model.freeze_unet_params()

        # # check unet
        # everything expect down,mid,up blocks
        # everything except down,mid,up blocks
        modules_from_unet = [
            model.base_time_embedding,
            model.base_conv_in,
@@ -236,7 +237,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
            assert_frozen(u.upsamplers)

        # # check controlnet
        # everything expect down,mid,up blocks
        # everything except down,mid,up blocks
        modules_from_controlnet = [
            model.controlnet_cond_embedding,
            model.ctrl_conv_in,
@@ -267,16 +268,6 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
        for u in model.up_blocks:
            assert_unfrozen(u.ctrl_to_base)

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "Transformer2DModel",
            "UNetMidBlock2DCrossAttn",
            "ControlNetXSCrossAttnDownBlock2D",
            "ControlNetXSCrossAttnMidBlock2D",
            "ControlNetXSCrossAttnUpBlock2D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @is_flaky
    def test_forward_no_control(self):
        unet = self.get_dummy_unet()
@@ -287,7 +278,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
        unet = unet.to(torch_device)
        model = model.to(torch_device)

        input_ = self.dummy_input
        input_ = self.get_dummy_inputs()

        control_specific_input = ["controlnet_cond", "conditioning_scale"]
        input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
@@ -312,7 +303,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
        model = model.to(torch_device)
        model_mix_time = model_mix_time.to(torch_device)

        input_ = self.dummy_input
        input_ = self.get_dummy_inputs()

        with torch.no_grad():
            output = model(**input_).sample
@@ -320,7 +311,14 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes

        assert output.shape == output_mix_time.shape

    @unittest.skip("Test not supported.")
    def test_forward_with_norm_groups(self):
        # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
        pass

class TestUNetControlNetXSTraining(UNetControlNetXSTesterConfig, TrainingTesterMixin):
    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "Transformer2DModel",
            "UNetMidBlock2DCrossAttn",
            "ControlNetXSCrossAttnDownBlock2D",
            "ControlNetXSCrossAttnMidBlock2D",
            "ControlNetXSCrossAttnUpBlock2D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

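The ControlNet-XS helpers above compose the test model from a small UNet and an adapter built with `ControlNetXSAdapter.from_unet`; `test_from_unet` then checks the composed weights. A hedged sketch of that wiring, assuming a `unet` built as in `get_dummy_unet`:

```py
from diffusers import ControlNetXSAdapter, UNetControlNetXSModel

# size_ratio and conditioning_embedding_out_channels keep the adapter tiny for CI.
adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2))
model = UNetControlNetXSModel.from_unet(unet, adapter)
```
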
@@ -16,10 +16,10 @@
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import UNetSpatioTemporalConditionModel
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ...testing_utils import (
|
||||
@@ -28,45 +28,34 @@ from ...testing_utils import (
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@skip_mps
|
||||
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNetSpatioTemporalConditionModel
|
||||
main_input_name = "sample"
|
||||
class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNetSpatioTemporalConditionModel testing."""
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 2
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"added_time_ids": self._get_add_time_ids(),
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (2, 2, 4, 32, 32)
|
||||
def model_class(self):
|
||||
return UNetSpatioTemporalConditionModel
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
return 6
|
||||
@@ -83,8 +72,8 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
|
||||
def addition_time_embed_dim(self):
|
||||
return 32
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"block_out_channels": (32, 64),
|
||||
"down_block_types": (
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
@@ -103,8 +92,23 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
|
||||
"projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3,
|
||||
"addition_time_embed_dim": self.addition_time_embed_dim,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 2
|
||||
num_frames = 2
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"added_time_ids": self._get_add_time_ids(),
|
||||
}
|
||||
|
||||
def _get_add_time_ids(self, do_classifier_free_guidance=True):
|
||||
add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength]
|
||||
@@ -124,43 +128,15 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
|
||||
|
||||
return add_time_ids
|
||||
|
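The diff elides most of `_get_add_time_ids` between the hunks above. A hedged sketch of a body consistent with what is visible (the three scalars, `projection_class_embeddings_input_dim = addition_time_embed_dim * 3`, and the `do_classifier_free_guidance` flag); treat it as an assumption, not the file's actual code:

```py
def _get_add_time_ids(self, do_classifier_free_guidance=True):
    # three conditioning scalars for SVD-style UNets; 3 ids * addition_time_embed_dim (32)
    # matches projection_class_embeddings_input_dim in get_init_dict
    add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength]
    add_time_ids = torch.tensor([add_time_ids], device=torch_device)

    if do_classifier_free_guidance:
        # duplicate so the batch covers the unconditional and conditional passes
        add_time_ids = torch.cat([add_time_ids, add_time_ids])

    return add_time_ids
```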
    @unittest.skip("Number of Norm Groups is not configurable")

class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, UNetTesterMixin):
    @pytest.mark.skip("Number of Norm Groups is not configurable")
    def test_forward_with_norm_groups(self):
        pass

    @unittest.skip("Deprecated functionality")
    def test_model_attention_slicing(self):
        pass

    @unittest.skip("Not supported")
    def test_model_with_use_linear_projection(self):
        pass

    @unittest.skip("Not supported")
    def test_model_with_simple_projection(self):
        pass

    @unittest.skip("Not supported")
    def test_model_with_class_embeddings_concat(self):
        pass

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        model.enable_xformers_memory_efficient_attention()

        assert (
            model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
            == "XFormersAttnProcessor"
        ), "xformers is not enabled"

    def test_model_with_num_attention_heads_tuple(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["num_attention_heads"] = (8, 16)
        model = self.model_class(**init_dict)
@@ -173,12 +149,13 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
        assert output.shape == expected_shape, "Input and output shapes do not match"

    def test_model_with_cross_attention_dim_tuple(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["cross_attention_dim"] = (32, 32)

@@ -192,27 +169,13 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "TransformerSpatioTemporalModel",
            "CrossAttnDownBlockSpatioTemporal",
            "DownBlockSpatioTemporal",
            "UpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "UNetMidBlockSpatioTemporal",
        }
        num_attention_heads = (8, 16)
        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, num_attention_heads=num_attention_heads
        )
        assert output.shape == expected_shape, "Input and output shapes do not match"

    def test_pickle(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["num_attention_heads"] = (8, 16)

@@ -225,3 +188,33 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
        sample_copy = copy.copy(sample)

        assert (sample - sample_copy).abs().max() < 1e-4


class TestUNetSpatioTemporalAttention(UNetSpatioTemporalTesterConfig, AttentionTesterMixin):
    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)

        model.enable_xformers_memory_efficient_attention()

        assert (
            model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
            == "XFormersAttnProcessor"
        ), "xformers is not enabled"


class TestUNetSpatioTemporalTraining(UNetSpatioTemporalTesterConfig, TrainingTesterMixin):
    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "TransformerSpatioTemporalModel",
            "CrossAttnDownBlockSpatioTemporal",
            "DownBlockSpatioTemporal",
            "UpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "UNetMidBlockSpatioTemporal",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

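Both training test classes hand an `expected_set` of block class names to the shared `test_gradient_checkpointing_is_applied` helper. A rough sketch of the kind of check such a helper can perform, assuming it simply enables checkpointing and inspects module classes (the real helper lives in the shared testing utilities and may differ):

```py
def check_gradient_checkpointing_is_applied(model, expected_set):
    # illustrative approximation, not the mixin's actual implementation
    model.enable_gradient_checkpointing()

    # collect every module class that actually had checkpointing switched on
    applied = {
        module.__class__.__name__
        for module in model.modules()
        if getattr(module, "gradient_checkpointing", False)
    }
    assert applied == expected_set, f"mismatched modules: {applied ^ expected_set}"
```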
@@ -1,6 +1,4 @@
import gc
import json
import os
import tempfile
from typing import Callable

@@ -10,7 +8,6 @@ import torch
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    ConfigSpec,
@@ -20,13 +17,7 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
)
from diffusers.utils import logging

from ..testing_utils import (
    CaptureLogger,
    backend_empty_cache,
    numpy_cosine_similarity_distance,
    require_accelerator,
    torch_device,
)
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device


class ModularPipelineTesterMixin:
@@ -409,56 +400,6 @@ class ModularGuiderTesterMixin:
        assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"


class TestCustomBlockRequirements:
    def get_dummy_block_pipe(self):
        class DummyBlockOne:
            # keep two arbitrary deps so that we can test warnings.
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            # keep two dependencies that will be available during testing.
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        pipe = SequentialPipelineBlocks.from_blocks_dict(
            {"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
        )
        return pipe

    def test_custom_requirements_save_load(self):
        pipe = self.get_dummy_block_pipe()
        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            config_path = os.path.join(tmpdir, "modular_config.json")
            with open(config_path, "r") as f:
                config = json.load(f)

        assert "requirements" in config
        requirements = config["requirements"]

        expected_requirements = {
            "xyz": ">=0.8.0",
            "abc": ">=10.0.0",
            "transformers": ">=4.44.0",
            "diffusers": ">=0.2.0",
        }
        assert expected_requirements == requirements

    def test_warnings(self):
        pipe = self.get_dummy_block_pipe()
        with tempfile.TemporaryDirectory() as tmpdir:
            logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
            logger.setLevel(30)

            with CaptureLogger(logger) as cap_logger:
                pipe.save_pretrained(tmpdir)

            template = "{req} was specified in the requirements but wasn't found in the current environment"
            msg_xyz = template.format(req="xyz")
            msg_abc = template.format(req="abc")
            assert msg_xyz in str(cap_logger.out)
            assert msg_abc in str(cap_logger.out)

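Two details worth noting in `TestCustomBlockRequirements`: `logger.setLevel(30)` is the numeric form of `logging.WARNING`, so only warnings and above are captured, and the save/load test pins down the observable contract that per-block `_requirements` are merged and written to `modular_config.json`. A condensed sketch of that round trip, using the same API the test exercises (the exact config layout is an implementation detail):

```py
import json
import os
import tempfile

from diffusers.modular_pipelines import SequentialPipelineBlocks


class BlockOne:
    _requirements = {"transformers": ">=4.44.0"}


class BlockTwo:
    _requirements = {"diffusers": ">=0.2.0"}


pipe = SequentialPipelineBlocks.from_blocks_dict({"one": BlockOne, "two": BlockTwo})

with tempfile.TemporaryDirectory() as tmpdir:
    pipe.save_pretrained(tmpdir)  # logs a warning for any requirement missing locally
    with open(os.path.join(tmpdir, "modular_config.json")) as f:
        print(json.load(f)["requirements"])  # merged dict from both blocks
```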

class TestModularModelCardContent:
    def create_mock_block(self, name="TestBlock", description="Test block description"):
        class MockBlock: