Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-26 05:25:49 +08:00)

Compare commits: modular-wo… to requiremen… (26 commits)
Commits (SHA1): 60e3284003, 4b843c8430, 7b43d0e409, 3879e32254, d7a1c31f4f, 29b15f41c7, 75edff93a0, 76f51a5e92, a88d11bc90, a9165eb749, eeb3445444, 5b7d0dfab6, 1de4402c26, 024c2b9839, 35d8d97c02, e52cabeff2, 2c4d73d72d, 046be83946, b7fba892f5, ecbd907e76, d159ae025d, 756a1567f5, d2731ababa, 37d3887194, 127e9a39d8, 12ceecf077
@@ -496,6 +496,8 @@
         title: Bria 3.2
       - local: api/pipelines/bria_fibo
         title: Bria Fibo
+      - local: api/pipelines/bria_fibo_edit
+        title: Bria Fibo Edit
       - local: api/pipelines/chroma
         title: Chroma
       - local: api/pipelines/cogview3
docs/source/en/api/pipelines/bria_fibo_edit.md (new file, 33 lines)
@@ -0,0 +1,33 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Bria Fibo Edit
Fibo Edit is an 8B-parameter image-to-image model that introduces a new paradigm of structured control: it operates on JSON inputs paired with source images to enable deterministic, repeatable editing workflows.
With native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments.
Its lightweight architecture is designed for deep customization, letting researchers build specialized "Edit" models for domain-specific tasks while delivering top-tier aesthetic quality.
## Usage

_The model is gated. Before using it with diffusers, go to the [Bria Fibo Edit Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form, and accept the gate. Once you are in, log in so that your system knows you've accepted the gate._

Use the command below to log in:

```bash
hf auth login
```
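Once logged in, the pipeline can be used like other diffusers image-to-image pipelines. The snippet below is a minimal sketch, not an official recipe: the repository id comes from the gate link above, the JSON prompt fields (`text`, `edit_instruction`) mirror the ones exercised in the fast tests, and the dtype, step count, and image URL are illustrative. Consult the model card for the full prompt schema and recommended settings.

```python
import json

import torch
from diffusers import BriaFiboEditPipeline
from diffusers.utils import load_image

# Checkpoint id taken from the gated repo linked above; the dtype is illustrative.
pipe = BriaFiboEditPipeline.from_pretrained("briaai/Fibo-Edit", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Any RGB image works here; this URL is illustrative.
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)

# Structured JSON prompt; the field names follow the pipeline's fast tests.
prompt = json.dumps(
    {
        "text": "A portrait of a girl with a pearl earring",
        "edit_instruction": "Make the earring gold",
    }
)

result = pipe(
    prompt=prompt,
    image=image,
    num_inference_steps=28,  # illustrative, not a tuned value
    guidance_scale=5.0,
).images[0]
result.save("fibo_edit_output.png")
```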

## BriaFiboEditPipeline

[[autodoc]] BriaFiboEditPipeline
  - all
  - __call__
@@ -457,6 +457,7 @@ else:
         "AuraFlowPipeline",
         "BlipDiffusionControlNetPipeline",
         "BlipDiffusionPipeline",
+        "BriaFiboEditPipeline",
         "BriaFiboPipeline",
         "BriaPipeline",
         "ChromaImg2ImgPipeline",
@@ -1185,6 +1186,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         AudioLDM2UNet2DConditionModel,
         AudioLDMPipeline,
         AuraFlowPipeline,
+        BriaFiboEditPipeline,
         BriaFiboPipeline,
         BriaPipeline,
         ChromaImg2ImgPipeline,
@@ -89,8 +89,6 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
         # automap = self._create_automap(parent_class=parent_class, child_class=child_class)
         # with open(CONFIG, "w") as f:
         #     json.dump(automap, f)
-        with open("requirements.txt", "w") as f:
-            f.write("")

     def _choose_block(self, candidates, chosen=None):
         for cls, base in candidates:
@@ -478,7 +478,7 @@ class PeftAdapterMixin:
         Args:
             adapter_names (`List[str]` or `str`):
                 The names of the adapters to use.
-            adapter_weights (`Union[List[float], float]`, *optional*):
+            weights (`Union[List[float], float]`, *optional*):
                 The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
                 adapters.

@@ -495,7 +495,7 @@ class PeftAdapterMixin:
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
         pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
-        pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        pipeline.unet.set_adapters(["cinematic", "pixel"], weights=[0.5, 0.5])
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -39,6 +39,7 @@ from .modular_pipeline_utils import (
     InputParam,
     InsertableDict,
     OutputParam,
+    _validate_requirements,
     format_components,
     format_configs,
     make_doc_string,
@@ -242,6 +243,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):

     config_name = "modular_config.json"
     model_name = None
+    _requirements: Optional[Dict[str, str]] = None

     @classmethod
     def _get_signature_keys(cls, obj):
@@ -304,6 +306,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
         trust_remote_code: bool = False,
         **kwargs,
     ):
+        config = cls.load_config(pretrained_model_name_or_path)
+        has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
+        trust_remote_code = resolve_trust_remote_code(
+            trust_remote_code, pretrained_model_name_or_path, has_remote_code
+        )
+        if not (has_remote_code and trust_remote_code):
+            raise ValueError(
+                "Selected model repository does not appear to have any custom code or does not have a valid `config.json` file."
+            )
+
+        if "requirements" in config and config["requirements"] is not None:
+            _ = _validate_requirements(config["requirements"])
+
         hub_kwargs_names = [
             "cache_dir",
             "force_download",
@@ -316,16 +331,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
         ]
         hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}

-        config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
-        has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
-        trust_remote_code = resolve_trust_remote_code(
-            trust_remote_code, pretrained_model_name_or_path, has_remote_code
-        )
-        if not has_remote_code and trust_remote_code:
-            raise ValueError(
-                "Selected model repository does not appear to have any custom code or does not have a valid `config.json` file."
-            )
-
         class_ref = config["auto_map"][cls.__name__]
         module_file, class_name = class_ref.split(".")
         module_file = module_file + ".py"
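Taken together, these two hunks move the config load ahead of the hub-kwargs handling and validate any `requirements` declared in `modular_config.json` before fetching the remote module. A minimal sketch of the resulting call, with a hypothetical repo id:

```python
from diffusers import ModularPipelineBlocks

# Hypothetical repository hosting custom blocks. `trust_remote_code=True` is required
# because the block class lives in the repo, not in diffusers. If the repo's
# modular_config.json declares "requirements", they are validated against the local
# environment on load, logging warnings for missing or version-mismatched packages.
blocks = ModularPipelineBlocks.from_pretrained(
    "your-namespace/custom-blocks",
    trust_remote_code=True,
)
```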
@@ -350,8 +355,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
         module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
         parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
         auto_map = {f"{parent_module}": f"{module}.{cls_name}"}

         self.register_to_config(auto_map=auto_map)

+        # resolve requirements
+        requirements = _validate_requirements(getattr(self, "_requirements", None))
+        if requirements:
+            self.register_to_config(requirements=requirements)
+
         self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
         config = dict(self.config)
         self._internal_dict = FrozenDict(config)
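On save, any `_requirements` declared on the blocks are validated and written into `modular_config.json`. A minimal sketch mirroring the tests further down (the block classes here are hypothetical; a bare version string is pinned with `==` during normalization):

```python
import json
import os
import tempfile

from diffusers.modular_pipelines import SequentialPipelineBlocks


# Hypothetical blocks: any block may declare a `_requirements` mapping of
# package name -> version specifier.
class PreprocessBlock:
    _requirements = {"transformers": ">=4.44.0"}


class DenoiseBlock:
    _requirements = {"diffusers": ">=0.2.0", "torch": "2.4.0"}  # pinned to torch==2.4.0


blocks = SequentialPipelineBlocks.from_blocks_dict(
    {"preprocess": PreprocessBlock, "denoise": DenoiseBlock}
)

with tempfile.TemporaryDirectory() as tmpdir:
    # save_pretrained validates the merged requirements (warning on missing or
    # mismatched packages) and records them in modular_config.json.
    blocks.save_pretrained(tmpdir)
    with open(os.path.join(tmpdir, "modular_config.json")) as f:
        print(json.load(f)["requirements"])
    # e.g. {"transformers": ">=4.44.0", "diffusers": ">=0.2.0", "torch": "==2.4.0"}
```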
@@ -1154,6 +1164,14 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
             expected_configs=self.expected_configs,
         )

+    @property
+    def _requirements(self) -> Dict[str, str]:
+        requirements = {}
+        for block_name, block in self.sub_blocks.items():
+            if getattr(block, "_requirements", None):
+                requirements[block_name] = block._requirements
+        return requirements
+

 class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
     """
@@ -19,10 +19,12 @@ from dataclasses import dataclass, field, fields
 from typing import Any, Dict, List, Literal, Optional, Type, Union

 import torch
+from packaging.specifiers import InvalidSpecifier, SpecifierSet

 from ..configuration_utils import ConfigMixin, FrozenDict
 from ..loaders.single_file_utils import _is_single_file_path_or_url
 from ..utils import is_torch_available, logging
+from ..utils.import_utils import _is_package_available


 if is_torch_available():
@@ -690,3 +692,86 @@ def make_doc_string(
     output += format_output_params(outputs, indent_level=2)

     return output
+
+
+def _validate_requirements(reqs):
+    if reqs is None:
+        normalized_reqs = {}
+    else:
+        if not isinstance(reqs, dict):
+            raise ValueError(
+                "Requirements must be provided as a dictionary mapping package names to version specifiers."
+            )
+        normalized_reqs = _normalize_requirements(reqs)
+
+    if not normalized_reqs:
+        return {}
+
+    final: Dict[str, str] = {}
+    for req, specified_ver in normalized_reqs.items():
+        req_available, req_actual_ver = _is_package_available(req)
+        if not req_available:
+            logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
+
+        if specified_ver:
+            try:
+                specifier = SpecifierSet(specified_ver)
+            except InvalidSpecifier as err:
+                raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
+
+            if req_actual_ver == "N/A":
+                logger.warning(
+                    f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpectedly."
+                )
+            elif not specifier.contains(req_actual_ver, prereleases=True):
+                logger.warning(
+                    f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpectedly."
+                )
+
+        final[req] = specified_ver
+
+    return final
+
+
+def _normalize_requirements(reqs):
+    if not reqs:
+        return {}
+
+    normalized: "OrderedDict[str, str]" = OrderedDict()
+
+    def _accumulate(mapping: Dict[str, Any]):
+        for pkg, spec in mapping.items():
+            if isinstance(spec, dict):
+                # This is recursive because blocks are composable. This way, we can merge requirements
+                # from multiple blocks.
+                _accumulate(spec)
+                continue
+
+            pkg_name = str(pkg).strip()
+            if not pkg_name:
+                raise ValueError("Requirement package name cannot be empty.")
+
+            spec_str = "" if spec is None else str(spec).strip()
+            if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
+                spec_str = f"=={spec_str}"
+
+            existing_spec = normalized.get(pkg_name)
+            if existing_spec is not None:
+                if not existing_spec and spec_str:
+                    normalized[pkg_name] = spec_str
+                elif existing_spec and spec_str and existing_spec != spec_str:
+                    try:
+                        combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
+                    except InvalidSpecifier:
+                        logger.warning(
+                            f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
+                        )
+                    else:
+                        normalized[pkg_name] = str(combined_spec)
+                continue
+
+            normalized[pkg_name] = spec_str
+
+    _accumulate(reqs)
+
+    return normalized
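As a quick illustration of `_normalize_requirements` (package names chosen arbitrarily): nested mappings produced by composed blocks are flattened recursively, bare versions are pinned with `==`, and overlapping specifiers for the same package are merged into one specifier set.

```python
from diffusers.modular_pipelines.modular_pipeline_utils import _normalize_requirements

# Nested mappings (as produced by SequentialPipelineBlocks._requirements) are
# flattened; overlapping specifiers for the same package are combined.
reqs = {
    "block_one": {"numpy": ">=1.24"},
    "block_two": {"numpy": "<2.0", "torch": "2.4.0"},
}
print(dict(_normalize_requirements(reqs)))
# -> {'numpy': '<2.0,>=1.24', 'torch': '==2.4.0'}
```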
@@ -129,7 +129,7 @@ else:
         "AnimateDiffVideoToVideoControlNetPipeline",
     ]
     _import_structure["bria"] = ["BriaPipeline"]
-    _import_structure["bria_fibo"] = ["BriaFiboPipeline"]
+    _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
     _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
     _import_structure["flux"] = [
         "FluxControlPipeline",

@@ -597,7 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .aura_flow import AuraFlowPipeline
     from .blip_diffusion import BlipDiffusionPipeline
     from .bria import BriaPipeline
-    from .bria_fibo import BriaFiboPipeline
+    from .bria_fibo import BriaFiboEditPipeline, BriaFiboPipeline
     from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline
     from .chronoedit import ChronoEditPipeline
     from .cogvideo import (
@@ -23,6 +23,8 @@ except OptionalDependencyNotAvailable:
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
+    _import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"]


 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:

@@ -33,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
         from .pipeline_bria_fibo import BriaFiboPipeline
+        from .pipeline_bria_fibo_edit import BriaFiboEditPipeline

 else:
     import sys
src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py (new file, 1133 lines)
File diff suppressed because it is too large
@@ -84,7 +84,6 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
|
||||
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
|
||||
>>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
>>> controlnet = ControlNetModel.from_pretrained(
|
||||
|
||||
@@ -53,7 +53,6 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
>>> from diffusers import HiDreamImagePipeline
|
||||
|
||||
|
||||
>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
|
||||
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
|
||||
... "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
|
||||
@@ -85,7 +85,6 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
|
||||
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
|
||||
>>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
>>> controlnet = ControlNetModel.from_pretrained(
|
||||
|
||||
@@ -459,7 +459,6 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
>>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
|
||||
>>> import torch
|
||||
|
||||
|
||||
>>> pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
|
||||
... )
|
||||
|
||||
@@ -587,6 +587,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])


+class BriaFiboEditPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class BriaFiboPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
@@ -1,4 +1,6 @@
 import gc
+import json
+import os
 import tempfile
 from typing import Callable, Union

@@ -8,9 +10,16 @@ import torch

 import diffusers
 from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
 from diffusers.guiders import ClassifierFreeGuidance
 from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.utils import logging

-from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
+from ..testing_utils import (
+    CaptureLogger,
+    backend_empty_cache,
+    numpy_cosine_similarity_distance,
+    require_accelerator,
+    torch_device,
+)


 class ModularPipelineTesterMixin:
@@ -335,3 +344,53 @@ class ModularGuiderTesterMixin:
         assert out_cfg.shape == out_no_cfg.shape
         max_diff = torch.abs(out_cfg - out_no_cfg).max()
         assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
+
+
+class TestCustomBlockRequirements:
+    def get_dummy_block_pipe(self):
+        class DummyBlockOne:
+            # keep two arbitrary deps so that we can test warnings.
+            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
+
+        class DummyBlockTwo:
+            # keep two dependencies that will be available during testing.
+            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
+
+        pipe = SequentialPipelineBlocks.from_blocks_dict(
+            {"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
+        )
+        return pipe
+
+    def test_custom_requirements_save_load(self):
+        pipe = self.get_dummy_block_pipe()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            config_path = os.path.join(tmpdir, "modular_config.json")
+            with open(config_path, "r") as f:
+                config = json.load(f)
+
+            assert "requirements" in config
+            requirements = config["requirements"]
+
+            expected_requirements = {
+                "xyz": ">=0.8.0",
+                "abc": ">=10.0.0",
+                "transformers": ">=4.44.0",
+                "diffusers": ">=0.2.0",
+            }
+            assert expected_requirements == requirements
+
+    def test_warnings(self):
+        pipe = self.get_dummy_block_pipe()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
+            logger.setLevel(30)
+
+            with CaptureLogger(logger) as cap_logger:
+                pipe.save_pretrained(tmpdir)
+
+            template = "{req} was specified in the requirements but wasn't found in the current environment"
+            msg_xyz = template.format(req="xyz")
+            msg_abc = template.format(req="abc")
+            assert msg_xyz in str(cap_logger.out)
+            assert msg_abc in str(cap_logger.out)
tests/pipelines/bria_fibo_edit/__init__.py (new file, 0 lines)

tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py (new file, 192 lines)
@@ -0,0 +1,192 @@
# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM

from diffusers import (
    AutoencoderKLWan,
    BriaFiboEditPipeline,
    FlowMatchEulerDiscreteScheduler,
)
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from tests.pipelines.test_pipelines_common import PipelineTesterMixin

from ...testing_utils import (
    enable_full_determinism,
    torch_device,
)


enable_full_determinism()


class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = BriaFiboEditPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale"])
    batch_params = frozenset(["prompt"])
    test_xformers_attention = False
    test_layerwise_casting = False
    test_group_offloading = False
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = BriaFiboTransformer2DModel(
            patch_size=1,
            in_channels=16,
            num_layers=1,
            num_single_layers=1,
            attention_head_dim=8,
            num_attention_heads=2,
            joint_attention_dim=64,
            text_encoder_dim=32,
            pooled_projection_dim=None,
            axes_dims_rope=[0, 4, 4],
        )

        vae = AutoencoderKLWan(
            base_dim=80,
            decoder_base_dim=128,
            dim_mult=[1, 2, 4, 4],
            dropout=0.0,
            in_channels=12,
            latents_mean=[0.0] * 16,
            latents_std=[1.0] * 16,
            is_residual=True,
            num_res_blocks=2,
            out_channels=12,
            patch_size=2,
            scale_factor_spatial=16,
            scale_factor_temporal=4,
            temperal_downsample=[False, True, True],
            z_dim=16,
        )
        scheduler = FlowMatchEulerDiscreteScheduler()
        text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        components = {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)
        inputs = {
            "prompt": '{"text": "A painting of a squirrel eating a burger","edit_instruction": "A painting of a squirrel eating a burger"}',
            "negative_prompt": "bad, ugly",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 192,
            "width": 336,
            "output_type": "np",
        }
        image = Image.new("RGB", (336, 192), (255, 255, 255))
        inputs["image"] = image
        return inputs

    @unittest.skip(reason="will not be supported due to dim-fusion")
    def test_encode_prompt_works_in_isolation(self):
        pass

    @unittest.skip(reason="Batching is not supported yet")
    def test_num_images_per_prompt(self):
        pass

    @unittest.skip(reason="Batching is not supported yet")
    def test_inference_batch_consistent(self):
        pass

    @unittest.skip(reason="Batching is not supported yet")
    def test_inference_batch_single_identical(self):
        pass

    def test_bria_fibo_different_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        output_same_prompt = pipe(**inputs).images[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = {"edit_instruction": "a different prompt"}
        output_different_prompts = pipe(**inputs).images[0]

        max_diff = np.abs(output_same_prompt - output_different_prompts).max()
        assert max_diff > 1e-6

    def test_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (64, 64), (32, 64)]
        for height, width in height_width_pairs:
            expected_height = height
            expected_width = width

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            assert (output_height, output_width) == (expected_height, expected_width)

    def test_bria_fibo_edit_mask(self):
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        mask = Image.fromarray((np.ones((192, 336)) * 255).astype(np.uint8), mode="L")

        inputs.update({"mask": mask})
        output = pipe(**inputs).images[0]

        assert output.shape == (192, 336, 3)

    def test_bria_fibo_edit_mask_image_size_mismatch(self):
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        mask = Image.fromarray((np.ones((64, 64)) * 255).astype(np.uint8), mode="L")

        inputs.update({"mask": mask})
        with self.assertRaisesRegex(ValueError, "Mask and image must have the same size"):
            pipe(**inputs)

    def test_bria_fibo_edit_mask_no_image(self):
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        mask = Image.fromarray((np.ones((32, 32)) * 255).astype(np.uint8), mode="L")

        # get_dummy_inputs adds an image by default; drop it so only the mask is passed.
        inputs.pop("image", None)
        inputs.update({"mask": mask})

        with self.assertRaisesRegex(ValueError, "If mask is provided, image must also be provided"):
            pipe(**inputs)