Compare commits


6 Commits

Author SHA1 Message Date
Sayak Paul
835a087a47 Merge branch 'main' into apply-lora-scale-decorator 2026-01-20 10:44:21 +05:30
Gal Davidi
d7a1c31f4f Fibo edit pipeline (#12930)
* Feature: Add BriaFiboEditPipeline to diffusers

* Introduced BriaFiboEditPipeline class with necessary backend requirements.
* Updated import structures in relevant modules to include BriaFiboEditPipeline.
* Ensured compatibility with existing pipelines and type checking.

* Feature: Introduce Bria Fibo Edit Pipeline

* Added BriaFiboEditPipeline class for structured JSON-native image editing.
* Created documentation for the new pipeline in bria_fibo_edit.md.
* Updated import structures to include the new pipeline and its components.
* Added unit tests for the BriaFiboEditPipeline to ensure functionality and correctness.

* Enhancement: Update Bria Fibo Edit Pipeline and Documentation

* Refined the Bria Fibo Edit model description for clarity and detail.
* Added usage instructions for model authentication and login.
* Implemented mask handling functions in the BriaFiboEditPipeline for improved image editing capabilities.
* Updated unit tests to cover new mask functionalities and ensure input validation.
* Adjusted example code in documentation to reflect changes in the pipeline's usage.

* Update Bria Fibo Edit documentation with corrected Hugging Face page link

* add dreambooth training script

* style and quality

* Delete temp.py

* Enhancement: Improve JSON caption validation in DreamBoothDataset

* Updated the clean_json_caption function to handle both string and dictionary inputs for captions.
* Added error handling to raise a ValueError for invalid caption types, ensuring better input validation.
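Editor's note: the DreamBooth training script was later removed from this PR (see the final bullet below), so `clean_json_caption` does not appear in the final diff. As a hedged illustration of the contract these two bullets describe, here is a minimal sketch; the body and error message are assumptions, not the PR's actual code:

```python
import json


def clean_json_caption(caption):
    """Normalize a caption that may arrive as a JSON string or an already-parsed dict."""
    if isinstance(caption, str):
        # Assume string captions are JSON-encoded; re-serialize to a canonical form.
        try:
            return json.dumps(json.loads(caption))
        except json.JSONDecodeError:
            # Fall back to treating the string as a plain-text caption.
            return caption
    if isinstance(caption, dict):
        # Dict captions are serialized back to a JSON string.
        return json.dumps(caption)
    raise ValueError(f"Invalid caption type: {type(caption)}. Expected str or dict.")
```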

* Add datasets dependency to requirements_fibo_edit.txt

* Add bria_fibo_edit to docs table of contents

* Fix dummy objects ordering

* Fix BriaFiboEditPipeline to use passed generator parameter

The pipeline was ignoring the generator parameter and only using
the seed parameter. This caused non-deterministic outputs in tests
that pass a seeded generator.
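Editor's note: the pipeline's actual code is in the suppressed diff below, so as a hedged sketch of the pattern this fix describes (names here are assumptions): a caller-supplied generator should win, and `seed` should only be used to build one when no generator is passed.

```python
from typing import Optional

import torch


def resolve_generator(generator: Optional[torch.Generator] = None, seed: Optional[int] = None) -> Optional[torch.Generator]:
    # Honor an explicitly passed generator first, so seeded generators from tests are respected.
    if generator is not None:
        return generator
    # Only fall back to constructing a generator when a seed is given.
    if seed is not None:
        return torch.Generator(device="cpu").manual_seed(seed)
    return None
```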

* Remove fibo_edit training script and related files

---------

Co-authored-by: kfirbria <kfir@bria.ai>
2026-01-19 22:09:53 +05:30
Sayak Paul
29b15f41c7 [chore] make style to push new changes. (#12998)
make style to push new changes.
2026-01-19 16:02:13 +05:30
sayakpaul
75edff93a0 Revert "make style && make quality"
This reverts commit 76f51a5e92.
2026-01-19 15:35:20 +05:30
sayakpaul
76f51a5e92 make style && make quality 2026-01-19 15:34:29 +05:30
sayakpaul
afa4a23c6c feat: implement apply_lora_scale to remove boilerplate. 2026-01-19 10:04:24 +05:30
16 changed files with 1439 additions and 25 deletions

View File

@@ -496,6 +496,8 @@
         title: Bria 3.2
       - local: api/pipelines/bria_fibo
         title: Bria Fibo
+      - local: api/pipelines/bria_fibo_edit
+        title: Bria Fibo Edit
       - local: api/pipelines/chroma
         title: Chroma
       - local: api/pipelines/cogview3

View File

@@ -0,0 +1,33 @@
+<!--Copyright 2025 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# Bria Fibo Edit
+
+Fibo Edit is an 8B-parameter image-to-image model that introduces a new paradigm of structured control: it operates on JSON inputs paired with source images to enable deterministic, repeatable editing workflows.
+Featuring native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments.
+Its lightweight architecture is designed for deep customization, empowering researchers to build specialized "Edit" models for domain-specific tasks while delivering top-tier aesthetic quality.
+
+## Usage
+
+_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Edit Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form, and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate._
+
+Use the command below to log in:
+
+```bash
+hf auth login
+```
+
+## BriaFiboEditPipeline
+
+[[autodoc]] BriaFiboEditPipeline
+  - all
+  - __call__
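Editor's note: the doc page above only covers authentication. For orientation, a minimal usage sketch under stated assumptions: the repo id comes from the gate link above, the JSON prompt keys (`text`, `edit_instruction`) and call arguments mirror the dummy inputs in the test file at the bottom of this diff, and none of this is otherwise confirmed by this page.

```python
import json

import torch
from diffusers import BriaFiboEditPipeline
from diffusers.utils import load_image

# Gated model: run `hf auth login` first. Repo id taken from the gate link above.
pipe = BriaFiboEditPipeline.from_pretrained("briaai/Fibo-Edit", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
)

# Structured JSON prompt; keys mirror the dummy test inputs in this PR.
prompt = json.dumps({"text": "a watercolor landscape", "edit_instruction": "turn the photo into a watercolor painting"})

result = pipe(prompt=prompt, image=image, num_inference_steps=30, guidance_scale=5.0).images[0]
result.save("fibo_edit_out.png")
```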

View File

@@ -457,6 +457,7 @@ else:
             "AuraFlowPipeline",
             "BlipDiffusionControlNetPipeline",
             "BlipDiffusionPipeline",
+            "BriaFiboEditPipeline",
             "BriaFiboPipeline",
             "BriaPipeline",
             "ChromaImg2ImgPipeline",
@@ -1185,6 +1186,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             AudioLDM2UNet2DConditionModel,
             AudioLDMPipeline,
             AuraFlowPipeline,
+            BriaFiboEditPipeline,
             BriaFiboPipeline,
             BriaPipeline,
             ChromaImg2ImgPipeline,

View File

@@ -22,7 +22,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -634,6 +634,7 @@ class FluxTransformer2DModel(
         self.gradient_checkpointing = False

+    @apply_lora_scale("joint_attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -675,20 +676,6 @@
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
         hidden_states = self.x_embedder(hidden_states)
@@ -785,10 +772,6 @@
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)

View File

@@ -129,7 +129,7 @@ else:
         "AnimateDiffVideoToVideoControlNetPipeline",
     ]
     _import_structure["bria"] = ["BriaPipeline"]
-    _import_structure["bria_fibo"] = ["BriaFiboPipeline"]
+    _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
     _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
     _import_structure["flux"] = [
         "FluxControlPipeline",
@@ -597,7 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .aura_flow import AuraFlowPipeline
     from .blip_diffusion import BlipDiffusionPipeline
     from .bria import BriaPipeline
-    from .bria_fibo import BriaFiboPipeline
+    from .bria_fibo import BriaFiboEditPipeline, BriaFiboPipeline
     from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline
     from .chronoedit import ChronoEditPipeline
     from .cogvideo import (

View File

@@ -23,6 +23,8 @@ except OptionalDependencyNotAvailable:
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
+    _import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -33,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
         from .pipeline_bria_fibo import BriaFiboPipeline
+        from .pipeline_bria_fibo_edit import BriaFiboEditPipeline
 else:
     import sys
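Editor's note: all three `__init__` hunks above follow diffusers' lazy-import convention: public names are registered in `_import_structure` and only materialized on first attribute access. A condensed sketch of the pattern, paraphrased from the surrounding files rather than copied verbatim:

```python
import sys
from typing import TYPE_CHECKING

from diffusers.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule

# Map submodule -> public names; nothing heavy is imported yet.
_import_structure = {"pipeline_bria_fibo_edit": ["BriaFiboEditPipeline"]}

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Type checkers (and slow-import mode) get real eager imports.
    from .pipeline_bria_fibo_edit import BriaFiboEditPipeline
else:
    # At runtime the module is swapped for a lazy proxy that imports on demand.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
```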

File diff suppressed because it is too large

View File

@@ -84,7 +84,6 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
>>> from diffusers.utils import load_image
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
>>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
>>> controlnet = ControlNetModel.from_pretrained(

View File

@@ -53,7 +53,6 @@ EXAMPLE_DOC_STRING = """
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> from diffusers import HiDreamImagePipeline
>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
... "meta-llama/Meta-Llama-3.1-8B-Instruct",

View File

@@ -85,7 +85,6 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL
>>> from diffusers.utils import load_image
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
>>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
>>> controlnet = ControlNetModel.from_pretrained(

View File

@@ -459,7 +459,6 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
>>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
>>> import torch
>>> pipeline = StableDiffusionPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
... )

View File

@@ -130,6 +130,7 @@ from .loading_utils import get_module_from_name, get_submodule_by_name, load_ima
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (
+    apply_lora_scale,
     check_peft_version,
     delete_adapter_layers,
     get_adapter_name,

View File

@@ -587,6 +587,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])


+class BriaFiboEditPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class BriaFiboPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
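Editor's note: `DummyObject` is diffusers' stand-in metaclass for classes whose backends are missing, so touching the class raises a helpful error instead of failing at import time. A simplified sketch of the mechanism (the real helper in `diffusers.utils` checks availability and builds per-backend install hints):

```python
class DummyObject(type):
    """Metaclass: any class-level attribute access raises a backend error."""

    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)


def requires_backends(obj, backends):
    # Simplified: the real version suggests the matching `pip install` commands.
    name = obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")
```

With this in place, calling `BriaFiboEditPipeline.from_pretrained(...)` without torch/transformers installed raises a clear ImportError rather than an opaque ModuleNotFoundError.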

View File

@@ -16,6 +16,7 @@ PEFT utilities: Utilities related to peft library
 """

 import collections
+import functools
 import importlib
 from typing import Optional
@@ -275,6 +276,59 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
             module.set_scale(adapter_name, get_module_weight(weight, module_name))


+def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
+    """
+    Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.
+
+    This decorator extracts the `lora_scale` from the specified kwargs parameter, applies scaling before the forward
+    pass, and ensures unscaling happens after, even if an exception occurs.
+
+    Args:
+        kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
+            The name of the keyword argument that contains the LoRA scale. Common values include
+            `"joint_attention_kwargs"`, `"attention_kwargs"`, `"cross_attention_kwargs"`, etc.
+    """
+
+    def decorator(forward_fn):
+        @functools.wraps(forward_fn)
+        def wrapper(self, *args, **kwargs):
+            from . import USE_PEFT_BACKEND
+
+            lora_scale = 1.0
+            attention_kwargs = kwargs.get(kwargs_name)
+            if attention_kwargs is not None:
+                attention_kwargs = attention_kwargs.copy()
+                kwargs[kwargs_name] = attention_kwargs
+                # Check for `scale` before popping it, so the warning can actually fire.
+                if not USE_PEFT_BACKEND and attention_kwargs.get("scale", None) is not None:
+                    logger.warning(
+                        f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
+                    )
+                lora_scale = attention_kwargs.pop("scale", 1.0)
+
+            # Apply LoRA scaling if using the PEFT backend
+            if USE_PEFT_BACKEND:
+                # weight the lora layers by setting `lora_scale` for each PEFT layer
+                scale_lora_layers(self, lora_scale)
+
+            try:
+                # Execute the forward pass
+                return forward_fn(self, *args, **kwargs)
+            finally:
+                # Always unscale, even if the forward pass raises an exception
+                if USE_PEFT_BACKEND:
+                    # remove `lora_scale` from each PEFT layer
+                    unscale_lora_layers(self, lora_scale)
+
+        return wrapper
+
+    return decorator
+
+
 def check_peft_version(min_version: str) -> None:
     r"""
     Checks if the version of PEFT is compatible.
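Editor's note: to make the decorator's contract concrete, a minimal self-contained sketch of how a forward method opts in (the toy module is hypothetical; the PR's real consumer is `FluxTransformer2DModel.forward` above):

```python
import torch
from diffusers.utils import apply_lora_scale


class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    @apply_lora_scale("joint_attention_kwargs")
    def forward(self, hidden_states, joint_attention_kwargs=None):
        # By the time this body runs, the decorator has popped "scale" from
        # joint_attention_kwargs and, under the PEFT backend, scaled LoRA layers.
        return self.proj(hidden_states)


model = ToyTransformer()
out = model(torch.randn(2, 8), joint_attention_kwargs={"scale": 0.5})
```

Compared with the hand-rolled version removed from `FluxTransformer2DModel.forward`, the `try`/`finally` also guarantees unscaling when the forward pass raises.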

View File

@@ -0,0 +1,192 @@
+# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer
+from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM
+
+from diffusers import (
+    AutoencoderKLWan,
+    BriaFiboEditPipeline,
+    FlowMatchEulerDiscreteScheduler,
+)
+from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
+from tests.pipelines.test_pipelines_common import PipelineTesterMixin
+
+from ...testing_utils import (
+    enable_full_determinism,
+    torch_device,
+)
+
+enable_full_determinism()
+
+
+class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = BriaFiboEditPipeline
+    params = frozenset(["prompt", "height", "width", "guidance_scale"])
+    batch_params = frozenset(["prompt"])
+    test_xformers_attention = False
+    test_layerwise_casting = False
+    test_group_offloading = False
+    supports_dduf = False
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        transformer = BriaFiboTransformer2DModel(
+            patch_size=1,
+            in_channels=16,
+            num_layers=1,
+            num_single_layers=1,
+            attention_head_dim=8,
+            num_attention_heads=2,
+            joint_attention_dim=64,
+            text_encoder_dim=32,
+            pooled_projection_dim=None,
+            axes_dims_rope=[0, 4, 4],
+        )
+        vae = AutoencoderKLWan(
+            base_dim=80,
+            decoder_base_dim=128,
+            dim_mult=[1, 2, 4, 4],
+            dropout=0.0,
+            in_channels=12,
+            latents_mean=[0.0] * 16,
+            latents_std=[1.0] * 16,
+            is_residual=True,
+            num_res_blocks=2,
+            out_channels=12,
+            patch_size=2,
+            scale_factor_spatial=16,
+            scale_factor_temporal=4,
+            temperal_downsample=[False, True, True],
+            z_dim=16,
+        )
+        scheduler = FlowMatchEulerDiscreteScheduler()
+        text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        components = {
+            "scheduler": scheduler,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "transformer": transformer,
+            "vae": vae,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device="cpu").manual_seed(seed)
+
+        inputs = {
+            "prompt": '{"text": "A painting of a squirrel eating a burger","edit_instruction": "A painting of a squirrel eating a burger"}',
+            "negative_prompt": "bad, ugly",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "height": 192,
+            "width": 336,
+            "output_type": "np",
+        }
+        image = Image.new("RGB", (336, 192), (255, 255, 255))
+        inputs["image"] = image
+        return inputs
+
+    @unittest.skip(reason="will not be supported due to dim-fusion")
+    def test_encode_prompt_works_in_isolation(self):
+        pass
+
+    @unittest.skip(reason="Batching is not supported yet")
+    def test_num_images_per_prompt(self):
+        pass
+
+    @unittest.skip(reason="Batching is not supported yet")
+    def test_inference_batch_consistent(self):
+        pass
+
+    @unittest.skip(reason="Batching is not supported yet")
+    def test_inference_batch_single_identical(self):
+        pass
+
+    def test_bria_fibo_different_prompts(self):
+        pipe = self.pipeline_class(**self.get_dummy_components())
+        pipe = pipe.to(torch_device)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_same_prompt = pipe(**inputs).images[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt"] = {"edit_instruction": "a different prompt"}
+        output_different_prompts = pipe(**inputs).images[0]
+
+        max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+        assert max_diff > 1e-6
+
+    def test_image_output_shape(self):
+        pipe = self.pipeline_class(**self.get_dummy_components())
+        pipe = pipe.to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+
+        height_width_pairs = [(32, 32), (64, 64), (32, 64)]
+        for height, width in height_width_pairs:
+            expected_height = height
+            expected_width = width
+
+            inputs.update({"height": height, "width": width})
+            image = pipe(**inputs).images[0]
+            output_height, output_width, _ = image.shape
+            assert (output_height, output_width) == (expected_height, expected_width)
+
+    def test_bria_fibo_edit_mask(self):
+        pipe = self.pipeline_class(**self.get_dummy_components())
+        pipe = pipe.to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+        mask = Image.fromarray((np.ones((192, 336)) * 255).astype(np.uint8), mode="L")
+        inputs.update({"mask": mask})
+        output = pipe(**inputs).images[0]
+        assert output.shape == (192, 336, 3)
+
+    def test_bria_fibo_edit_mask_image_size_mismatch(self):
+        pipe = self.pipeline_class(**self.get_dummy_components())
+        pipe = pipe.to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+        mask = Image.fromarray((np.ones((64, 64)) * 255).astype(np.uint8), mode="L")
+        inputs.update({"mask": mask})
+        with self.assertRaisesRegex(ValueError, "Mask and image must have the same size"):
+            pipe(**inputs)
+
+    def test_bria_fibo_edit_mask_no_image(self):
+        pipe = self.pipeline_class(**self.get_dummy_components())
+        pipe = pipe.to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+        mask = Image.fromarray((np.ones((32, 32)) * 255).astype(np.uint8), mode="L")
+        # get_dummy_inputs adds an image by default, so drop it to exercise the mask-without-image path
+        inputs.pop("image", None)
+        inputs.update({"mask": mask})
+        with self.assertRaisesRegex(ValueError, "If mask is provided, image must also be provided"):
+            pipe(**inputs)
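Editor's note: the pipeline implementation itself sits in the suppressed diff above, but the three mask tests pin down its contract. A hedged sketch of the validation they imply; the function name and placement are assumptions, only the error messages are copied from the tests:

```python
from typing import Optional

from PIL import Image


def validate_mask_inputs(image: Optional[Image.Image] = None, mask: Optional[Image.Image] = None) -> None:
    """Enforce the mask/image invariants asserted by the tests above."""
    if mask is None:
        return
    if image is None:
        raise ValueError("If mask is provided, image must also be provided")
    if mask.size != image.size:
        raise ValueError("Mask and image must have the same size")
```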