Compare commits

...

26 Commits

Author SHA1 Message Date
Sayak Paul
60e3284003 Merge branch 'main' into requirements-custom-blocks 2026-01-20 19:10:24 +05:30
Guillaume Besson
4b843c8430 Fix variable name in docstring for PeftAdapterMixin.set_adapters (#13003)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-20 15:43:59 +05:30
sayakpaul
7b43d0e409 add tests 2026-01-20 09:29:32 +05:30
Sayak Paul
3879e32254 Merge branch 'main' into requirements-custom-blocks 2026-01-20 08:20:38 +05:30
Gal Davidi
d7a1c31f4f Fibo edit pipeline (#12930)
* Feature: Add BriaFiboEditPipeline to diffusers

* Introduced BriaFiboEditPipeline class with necessary backend requirements.
* Updated import structures in relevant modules to include BriaFiboEditPipeline.
* Ensured compatibility with existing pipelines and type checking.

* Feature: Introduce Bria Fibo Edit Pipeline

* Added BriaFiboEditPipeline class for structured JSON-native image editing.
* Created documentation for the new pipeline in bria_fibo_edit.md.
* Updated import structures to include the new pipeline and its components.
* Added unit tests for the BriaFiboEditPipeline to ensure functionality and correctness.

* Enhancement: Update Bria Fibo Edit Pipeline and Documentation

* Refined the Bria Fibo Edit model description for clarity and detail.
* Added usage instructions for model authentication and login.
* Implemented mask handling functions in the BriaFiboEditPipeline for improved image editing capabilities.
* Updated unit tests to cover new mask functionalities and ensure input validation.
* Adjusted example code in documentation to reflect changes in the pipeline's usage.

* Update Bria Fibo Edit documentation with corrected Hugging Face page link

* add dreambooth training script

* style and quality

* Delete temp.py

* Enhancement: Improve JSON caption validation in DreamBoothDataset

* Updated the clean_json_caption function to handle both string and dictionary inputs for captions.
* Added error handling to raise a ValueError for invalid caption types, ensuring better input validation.

* Add datasets dependency to requirements_fibo_edit.txt

* Add bria_fibo_edit to docs table of contents

* Fix dummy objects ordering

* Fix BriaFiboEditPipeline to use passed generator parameter

The pipeline was ignoring the generator parameter and only using
the seed parameter. This caused non-deterministic outputs in tests
that pass a seeded generator.

* Remove fibo_edit training script and related files

---------

Co-authored-by: kfirbria <kfir@bria.ai>
2026-01-19 22:09:53 +05:30
Sayak Paul
29b15f41c7 [chore] make style to push new changes. (#12998)
make style to push new changes.
2026-01-19 16:02:13 +05:30
sayakpaul
75edff93a0 Revert "make style && make quality"
This reverts commit 76f51a5e92.
2026-01-19 15:35:20 +05:30
sayakpaul
76f51a5e92 make style && make quality 2026-01-19 15:34:29 +05:30
sayakpaul
a88d11bc90 resolve conflicts. 2025-11-06 10:29:24 +05:30
Sayak Paul
a9165eb749 Merge branch 'main' into requirements-custom-blocks 2025-11-03 12:12:08 +05:30
Sayak Paul
eeb3445444 Merge branch 'main' into requirements-custom-blocks 2025-11-01 08:36:16 +05:30
Sayak Paul
5b7d0dfab6 Merge branch 'main' into requirements-custom-blocks 2025-10-29 16:30:46 +05:30
sayakpaul
1de4402c26 up 2025-10-27 13:55:17 +05:30
sayakpaul
024c2b9839 Merge branch 'main' into requirements-custom-blocks 2025-10-27 11:56:00 +05:30
Sayak Paul
35d8d97c02 Merge branch 'main' into requirements-custom-blocks 2025-10-22 21:57:45 +05:30
Sayak Paul
e52cabeff2 Merge branch 'main' into requirements-custom-blocks 2025-10-22 06:23:40 +05:30
Sayak Paul
2c4d73d72d Merge branch 'main' into requirements-custom-blocks 2025-10-21 01:54:38 +05:30
sayakpaul
046be83946 up 2025-10-02 15:43:44 +05:30
Sayak Paul
b7fba892f5 Merge branch 'main' into requirements-custom-blocks 2025-09-23 13:35:49 +05:30
Sayak Paul
ecbd907e76 Merge branch 'main' into requirements-custom-blocks 2025-09-12 15:47:22 +05:30
Sayak Paul
d159ae025d Merge branch 'main' into requirements-custom-blocks 2025-09-02 10:04:22 +05:30
Sayak Paul
756a1567f5 Merge branch 'main' into requirements-custom-blocks 2025-08-29 08:03:00 +02:00
Sayak Paul
d2731ababa Merge branch 'main' into requirements-custom-blocks 2025-08-21 07:59:54 +05:30
sayakpaul
37d3887194 unify. 2025-08-20 12:09:33 +05:30
sayakpaul
127e9a39d8 up 2025-08-20 11:51:15 +05:30
sayakpaul
12ceecf077 feat: implement requirements validation for custom blocks. 2025-08-20 11:04:28 +05:30
18 changed files with 1558 additions and 22 deletions

View File

@@ -496,6 +496,8 @@
title: Bria 3.2
- local: api/pipelines/bria_fibo
title: Bria Fibo
- local: api/pipelines/bria_fibo_edit
title: Bria Fibo Edit
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3

View File

@@ -0,0 +1,33 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Bria Fibo Edit
Fibo Edit is an 8B parameter image-to-image model that introduces a new paradigm of structured control, operating on JSON inputs paired with source images to enable deterministic and repeatable editing workflows.
Featuring native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments.
Its lightweight architecture is designed for deep customization, empowering researchers to build specialized "Edit" models for domain-specific tasks while delivering top-tier aesthetic quality
## Usage
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form and accept the gate. Once you are in, you need to login so that your system knows youve accepted the gate._
Use the command below to log in:
```bash
hf auth login
```
## BriaFiboEditPipeline
[[autodoc]] BriaFiboEditPipeline
- all
- __call__

View File

@@ -457,6 +457,7 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaFiboEditPipeline",
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
@@ -1185,6 +1186,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
BriaFiboEditPipeline,
BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,

View File

@@ -89,8 +89,6 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")
def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:

View File

@@ -478,7 +478,7 @@ class PeftAdapterMixin:
Args:
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
adapter_weights (`Union[List[float], float]`, *optional*):
weights (`Union[List[float], float]`, *optional*):
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
adapters.
@@ -495,7 +495,7 @@ class PeftAdapterMixin:
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
pipeline.unet.set_adapters(["cinematic", "pixel"], weights=[0.5, 0.5])
```
"""
if not USE_PEFT_BACKEND:

View File

@@ -39,6 +39,7 @@ from .modular_pipeline_utils import (
InputParam,
InsertableDict,
OutputParam,
_validate_requirements,
format_components,
format_configs,
make_doc_string,
@@ -242,6 +243,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
config_name = "modular_config.json"
model_name = None
_requirements: Optional[Dict[str, str]] = None
@classmethod
def _get_signature_keys(cls, obj):
@@ -304,6 +306,19 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
trust_remote_code: bool = False,
**kwargs,
):
config = cls.load_config(pretrained_model_name_or_path)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not (has_remote_code and trust_remote_code):
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
if "requirements" in config and config["requirements"] is not None:
_ = _validate_requirements(config["requirements"])
hub_kwargs_names = [
"cache_dir",
"force_download",
@@ -316,16 +331,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not has_remote_code and trust_remote_code:
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
@@ -350,8 +355,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
self.register_to_config(auto_map=auto_map)
# resolve requirements
requirements = _validate_requirements(getattr(self, "_requirements", None))
if requirements:
self.register_to_config(requirements=requirements)
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
@@ -1154,6 +1164,14 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs=self.expected_configs,
)
@property
def _requirements(self) -> Dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""

View File

@@ -19,10 +19,12 @@ from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url
from ..utils import is_torch_available, logging
from ..utils.import_utils import _is_package_available
if is_torch_available():
@@ -690,3 +692,86 @@ def make_doc_string(
output += format_output_params(outputs, indent_level=2)
return output
def _validate_requirements(reqs):
if reqs is None:
normalized_reqs = {}
else:
if not isinstance(reqs, dict):
raise ValueError(
"Requirements must be provided as a dictionary mapping package names to version specifiers."
)
normalized_reqs = _normalize_requirements(reqs)
if not normalized_reqs:
return {}
final: Dict[str, str] = {}
for req, specified_ver in normalized_reqs.items():
req_available, req_actual_ver = _is_package_available(req)
if not req_available:
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
if specified_ver:
try:
specifier = SpecifierSet(specified_ver)
except InvalidSpecifier as err:
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
if req_actual_ver == "N/A":
logger.warning(
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
)
elif not specifier.contains(req_actual_ver, prereleases=True):
logger.warning(
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
)
final[req] = specified_ver
return final
def _normalize_requirements(reqs):
if not reqs:
return {}
normalized: "OrderedDict[str, str]" = OrderedDict()
def _accumulate(mapping: Dict[str, Any]):
for pkg, spec in mapping.items():
if isinstance(spec, dict):
# This is recursive because blocks are composable. This way, we can merge requirements
# from multiple blocks.
_accumulate(spec)
continue
pkg_name = str(pkg).strip()
if not pkg_name:
raise ValueError("Requirement package name cannot be empty.")
spec_str = "" if spec is None else str(spec).strip()
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
spec_str = f"=={spec_str}"
existing_spec = normalized.get(pkg_name)
if existing_spec is not None:
if not existing_spec and spec_str:
normalized[pkg_name] = spec_str
elif existing_spec and spec_str and existing_spec != spec_str:
try:
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
except InvalidSpecifier:
logger.warning(
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
)
else:
normalized[pkg_name] = str(combined_spec)
continue
normalized[pkg_name] = spec_str
_accumulate(reqs)
return normalized

View File

@@ -129,7 +129,7 @@ else:
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
@@ -597,7 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .bria import BriaPipeline
from .bria_fibo import BriaFiboPipeline
from .bria_fibo import BriaFiboEditPipeline, BriaFiboPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline
from .chronoedit import ChronoEditPipeline
from .cogvideo import (

View File

@@ -23,6 +23,8 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
_import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -33,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_bria_fibo import BriaFiboPipeline
from .pipeline_bria_fibo_edit import BriaFiboEditPipeline
else:
import sys

File diff suppressed because it is too large Load Diff

View File

@@ -84,7 +84,6 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
>>> from diffusers.utils import load_image
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
>>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
>>> controlnet = ControlNetModel.from_pretrained(

View File

@@ -53,7 +53,6 @@ EXAMPLE_DOC_STRING = """
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> from diffusers import HiDreamImagePipeline
>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
... "meta-llama/Meta-Llama-3.1-8B-Instruct",

View File

@@ -85,7 +85,6 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL
>>> from diffusers.utils import load_image
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
>>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
>>> controlnet = ControlNetModel.from_pretrained(

View File

@@ -459,7 +459,6 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
>>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
>>> import torch
>>> pipeline = StableDiffusionPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
... )

View File

@@ -587,6 +587,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class BriaFiboEditPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class BriaFiboPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -1,4 +1,6 @@
import gc
import json
import os
import tempfile
from typing import Callable, Union
@@ -8,9 +10,16 @@ import torch
import diffusers
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.utils import logging
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
torch_device,
)
class ModularPipelineTesterMixin:
@@ -335,3 +344,53 @@ class ModularGuiderTesterMixin:
assert out_cfg.shape == out_no_cfg.shape
max_diff = torch.abs(out_cfg - out_no_cfg).max()
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def test_custom_requirements_save_load(self):
pipe = self.get_dummy_block_pipe()
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
config_path = os.path.join(tmpdir, "modular_config.json")
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_warnings(self):
pipe = self.get_dummy_block_pipe()
with tempfile.TemporaryDirectory() as tmpdir:
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(tmpdir)
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)

View File

@@ -0,0 +1,192 @@
# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM
from diffusers import (
AutoencoderKLWan,
BriaFiboEditPipeline,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from tests.pipelines.test_pipelines_common import PipelineTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
enable_full_determinism()
class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = BriaFiboEditPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
test_layerwise_casting = False
test_group_offloading = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = BriaFiboTransformer2DModel(
patch_size=1,
in_channels=16,
num_layers=1,
num_single_layers=1,
attention_head_dim=8,
num_attention_heads=2,
joint_attention_dim=64,
text_encoder_dim=32,
pooled_projection_dim=None,
axes_dims_rope=[0, 4, 4],
)
vae = AutoencoderKLWan(
base_dim=80,
decoder_base_dim=128,
dim_mult=[1, 2, 4, 4],
dropout=0.0,
in_channels=12,
latents_mean=[0.0] * 16,
latents_std=[1.0] * 16,
is_residual=True,
num_res_blocks=2,
out_channels=12,
patch_size=2,
scale_factor_spatial=16,
scale_factor_temporal=4,
temperal_downsample=[False, True, True],
z_dim=16,
)
scheduler = FlowMatchEulerDiscreteScheduler()
text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": '{"text": "A painting of a squirrel eating a burger","edit_instruction": "A painting of a squirrel eating a burger"}',
"negative_prompt": "bad, ugly",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 192,
"width": 336,
"output_type": "np",
}
image = Image.new("RGB", (336, 192), (255, 255, 255))
inputs["image"] = image
return inputs
@unittest.skip(reason="will not be supported due to dim-fusion")
def test_encode_prompt_works_in_isolation(self):
pass
@unittest.skip(reason="Batching is not supported yet")
def test_num_images_per_prompt(self):
pass
@unittest.skip(reason="Batching is not supported yet")
def test_inference_batch_consistent(self):
pass
@unittest.skip(reason="Batching is not supported yet")
def test_inference_batch_single_identical(self):
pass
def test_bria_fibo_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = {"edit_instruction": "a different prompt"}
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
assert max_diff > 1e-6
def test_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (64, 64), (32, 64)]
for height, width in height_width_pairs:
expected_height = height
expected_width = width
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
def test_bria_fibo_edit_mask(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
mask = Image.fromarray((np.ones((192, 336)) * 255).astype(np.uint8), mode="L")
inputs.update({"mask": mask})
output = pipe(**inputs).images[0]
assert output.shape == (192, 336, 3)
def test_bria_fibo_edit_mask_image_size_mismatch(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
mask = Image.fromarray((np.ones((64, 64)) * 255).astype(np.uint8), mode="L")
inputs.update({"mask": mask})
with self.assertRaisesRegex(ValueError, "Mask and image must have the same size"):
pipe(**inputs)
def test_bria_fibo_edit_mask_no_image(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
mask = Image.fromarray((np.ones((32, 32)) * 255).astype(np.uint8), mode="L")
# Remove image from inputs if it's there (it shouldn't be by default from get_dummy_inputs)
inputs.pop("image", None)
inputs.update({"mask": mask})
with self.assertRaisesRegex(ValueError, "If mask is provided, image must also be provided"):
pipe(**inputs)