Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-20 18:45:47 +08:00)

Compare commits: enable-cp-...apply-lora (6 commits)

- 835a087a47
- d7a1c31f4f
- 29b15f41c7
- 75edff93a0
- 76f51a5e92
- afa4a23c6c
docs/source/en/_toctree.yml

@@ -496,6 +496,8 @@
      title: Bria 3.2
    - local: api/pipelines/bria_fibo
      title: Bria Fibo
+   - local: api/pipelines/bria_fibo_edit
+     title: Bria Fibo Edit
    - local: api/pipelines/chroma
      title: Chroma
    - local: api/pipelines/cogview3

docs/source/en/api/pipelines/bria_fibo_edit.md (new file, 33 lines)
@@ -0,0 +1,33 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Bria Fibo Edit

Fibo Edit is an 8B parameter image-to-image model that introduces a new paradigm of structured control, operating on JSON inputs paired with source images to enable deterministic and repeatable editing workflows.
Featuring native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments.
Its lightweight architecture is designed for deep customization, empowering researchers to build specialized "Edit" models for domain-specific tasks while delivering top-tier aesthetic quality.

## Usage

_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form, and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate._

Use the command below to log in:

```bash
hf auth login
```
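
Below is a minimal editing sketch. The `BriaFiboEditPipeline` class, the `briaai/Fibo-Edit` checkpoint, and the `prompt`/`image`/`mask` arguments come from this PR; the JSON field names mirror the pipeline's tests, while the dtype, step count, and file paths are illustrative assumptions.

```python
import json

import torch
from diffusers import BriaFiboEditPipeline
from diffusers.utils import load_image

# Load the gated checkpoint (requires accepting the gate and logging in first).
pipe = BriaFiboEditPipeline.from_pretrained("briaai/Fibo-Edit", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Structured JSON prompt; the field names follow the pipeline tests and are
# illustrative rather than an exhaustive schema.
prompt = json.dumps(
    {
        "text": "A painting of a squirrel eating a burger",
        "edit_instruction": "Replace the burger with an acorn",
    }
)

image = load_image("source.png")  # source image to edit

# An optional "L"-mode PIL mask of the same size as `image` can be passed via
# `mask=` to restrict the edit to a region (native masking).
result = pipe(prompt=prompt, image=image, num_inference_steps=50, guidance_scale=5.0).images[0]
result.save("edited.png")
```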

## BriaFiboEditPipeline

[[autodoc]] BriaFiboEditPipeline
- all
- __call__

src/diffusers/__init__.py

@@ -457,6 +457,7 @@ else:
        "AuraFlowPipeline",
        "BlipDiffusionControlNetPipeline",
        "BlipDiffusionPipeline",
+       "BriaFiboEditPipeline",
        "BriaFiboPipeline",
        "BriaPipeline",
        "ChromaImg2ImgPipeline",

@@ -1185,6 +1186,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        AudioLDM2UNet2DConditionModel,
        AudioLDMPipeline,
        AuraFlowPipeline,
+       BriaFiboEditPipeline,
        BriaFiboPipeline,
        BriaPipeline,
        ChromaImg2ImgPipeline,

src/diffusers/models/transformers/transformer_flux.py

@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward

@@ -634,6 +634,7 @@ class FluxTransformer2DModel(
        self.gradient_checkpointing = False

+   @apply_lora_scale("joint_attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,

@@ -675,20 +676,6 @@ class FluxTransformer2DModel(
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
-       if joint_attention_kwargs is not None:
-           joint_attention_kwargs = joint_attention_kwargs.copy()
-           lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-       else:
-           lora_scale = 1.0
-
-       if USE_PEFT_BACKEND:
-           # weight the lora layers by setting `lora_scale` for each PEFT layer
-           scale_lora_layers(self, lora_scale)
-       else:
-           if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-               logger.warning(
-                   "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-               )

        hidden_states = self.x_embedder(hidden_states)

@@ -785,10 +772,6 @@ class FluxTransformer2DModel(
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

-       if USE_PEFT_BACKEND:
-           # remove `lora_scale` from each PEFT layer
-           unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

src/diffusers/pipelines/__init__.py

@@ -129,7 +129,7 @@ else:
        "AnimateDiffVideoToVideoControlNetPipeline",
    ]
    _import_structure["bria"] = ["BriaPipeline"]
-   _import_structure["bria_fibo"] = ["BriaFiboPipeline"]
+   _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
    _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
    _import_structure["flux"] = [
        "FluxControlPipeline",

@@ -597,7 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .aura_flow import AuraFlowPipeline
    from .blip_diffusion import BlipDiffusionPipeline
    from .bria import BriaPipeline
-   from .bria_fibo import BriaFiboPipeline
+   from .bria_fibo import BriaFiboEditPipeline, BriaFiboPipeline
    from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline
    from .chronoedit import ChronoEditPipeline
    from .cogvideo import (

src/diffusers/pipelines/bria_fibo/__init__.py

@@ -23,6 +23,8 @@ except OptionalDependencyNotAvailable:
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
+   _import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:

@@ -33,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_bria_fibo import BriaFiboPipeline
+       from .pipeline_bria_fibo_edit import BriaFiboEditPipeline

else:
    import sys

src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py (new file, 1133 lines)
File diff suppressed because it is too large.

@@ -84,7 +84,6 @@ EXAMPLE_DOC_STRING = """
        >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
        >>> from diffusers.utils import load_image

        >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
        >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
        >>> controlnet = ControlNetModel.from_pretrained(

@@ -53,7 +53,6 @@ EXAMPLE_DOC_STRING = """
        >>> from transformers import AutoTokenizer, LlamaForCausalLM
        >>> from diffusers import HiDreamImagePipeline

        >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
        >>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
        ...     "meta-llama/Meta-Llama-3.1-8B-Instruct",

@@ -85,7 +85,6 @@ EXAMPLE_DOC_STRING = """
        >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL
        >>> from diffusers.utils import load_image

        >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
        >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
        >>> controlnet = ControlNetModel.from_pretrained(

@@ -459,7 +459,6 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
        >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
        >>> import torch

        >>> pipeline = StableDiffusionPipeline.from_pretrained(
        ...     "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
        ... )

src/diffusers/utils/__init__.py

@@ -130,6 +130,7 @@ from .loading_utils import get_module_from_name, get_submodule_by_name, load_ima
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
+   apply_lora_scale,
    check_peft_version,
    delete_adapter_layers,
    get_adapter_name,

src/diffusers/utils/dummy_torch_and_transformers_objects.py

@@ -587,6 +587,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


+class BriaFiboEditPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
class BriaFiboPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

src/diffusers/utils/peft_utils.py

@@ -16,6 +16,7 @@ PEFT utilities: Utilities related to peft library
"""

import collections
+import functools
import importlib
from typing import Optional

@@ -275,6 +276,59 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
            module.set_scale(adapter_name, get_module_weight(weight, module_name))


+def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
+    """
+    Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.
+
+    This decorator extracts the LoRA `scale` from the specified kwargs parameter, applies scaling before the forward
+    pass, and ensures unscaling happens afterwards, even if an exception occurs.
+
+    Args:
+        kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
+            The name of the keyword argument that contains the LoRA scale. Common values include
+            `"joint_attention_kwargs"`, `"attention_kwargs"`, and `"cross_attention_kwargs"`.
+    """
+
+    def decorator(forward_fn):
+        @functools.wraps(forward_fn)
+        def wrapper(self, *args, **kwargs):
+            from . import USE_PEFT_BACKEND
+
+            lora_scale = 1.0
+            attention_kwargs = kwargs.get(kwargs_name)
+
+            if attention_kwargs is not None:
+                if not USE_PEFT_BACKEND and attention_kwargs.get("scale", None) is not None:
+                    logger.warning(
+                        f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
+                    )
+                # Copy so the caller's dict is not mutated when `scale` is popped.
+                attention_kwargs = attention_kwargs.copy()
+                kwargs[kwargs_name] = attention_kwargs
+                lora_scale = attention_kwargs.pop("scale", 1.0)
+
+            # Apply LoRA scaling if using the PEFT backend
+            if USE_PEFT_BACKEND:
+                scale_lora_layers(self, lora_scale)
+
+            try:
+                # Execute the forward pass
+                return forward_fn(self, *args, **kwargs)
+            finally:
+                # Always unscale, even if the forward pass raises an exception
+                if USE_PEFT_BACKEND:
+                    unscale_lora_layers(self, lora_scale)
+
+        return wrapper
+
+    return decorator
+
+
def check_peft_version(min_version: str) -> None:
    r"""
    Checks if the version of PEFT is compatible.
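
To make the decorator's contract concrete, here is a minimal sketch of how a module adopts it, mirroring the `transformer_flux.py` change above. Only `apply_lora_scale` and its re-export from `diffusers.utils` come from this diff; the toy module, shapes, and call site are illustrative assumptions.

```python
import torch

from diffusers.utils import apply_lora_scale  # exported by this PR's utils/__init__.py change


class ToyTransformer(torch.nn.Module):
    # Stand-in for a real diffusers transformer whose linear layers may carry LoRA adapters.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    @apply_lora_scale("joint_attention_kwargs")
    def forward(self, hidden_states: torch.Tensor, joint_attention_kwargs=None):
        # By the time this body runs, the decorator has copied the kwargs dict,
        # popped joint_attention_kwargs["scale"], and (with the PEFT backend)
        # called scale_lora_layers(self, scale). unscale_lora_layers then runs in
        # the decorator's `finally` block, even if this body raises.
        return self.proj(hidden_states)


model = ToyTransformer()
out = model(torch.randn(1, 16), joint_attention_kwargs={"scale": 0.5})
```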
tests/pipelines/bria_fibo_edit/__init__.py (new file, 0 lines)

tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py (new file, 192 lines)
@@ -0,0 +1,192 @@
# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM

from diffusers import (
    AutoencoderKLWan,
    BriaFiboEditPipeline,
    FlowMatchEulerDiscreteScheduler,
)
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from tests.pipelines.test_pipelines_common import PipelineTesterMixin

from ...testing_utils import (
    enable_full_determinism,
    torch_device,
)


enable_full_determinism()


class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = BriaFiboEditPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale"])
    batch_params = frozenset(["prompt"])
    test_xformers_attention = False
    test_layerwise_casting = False
    test_group_offloading = False
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = BriaFiboTransformer2DModel(
            patch_size=1,
            in_channels=16,
            num_layers=1,
            num_single_layers=1,
            attention_head_dim=8,
            num_attention_heads=2,
            joint_attention_dim=64,
            text_encoder_dim=32,
            pooled_projection_dim=None,
            axes_dims_rope=[0, 4, 4],
        )

        vae = AutoencoderKLWan(
            base_dim=80,
            decoder_base_dim=128,
            dim_mult=[1, 2, 4, 4],
            dropout=0.0,
            in_channels=12,
            latents_mean=[0.0] * 16,
            latents_std=[1.0] * 16,
            is_residual=True,
            num_res_blocks=2,
            out_channels=12,
            patch_size=2,
            scale_factor_spatial=16,
            scale_factor_temporal=4,
            temperal_downsample=[False, True, True],
            z_dim=16,
        )
        scheduler = FlowMatchEulerDiscreteScheduler()
        text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        components = {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)
        inputs = {
            "prompt": '{"text": "A painting of a squirrel eating a burger","edit_instruction": "A painting of a squirrel eating a burger"}',
            "negative_prompt": "bad, ugly",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 192,
            "width": 336,
            "output_type": "np",
        }
        image = Image.new("RGB", (336, 192), (255, 255, 255))
        inputs["image"] = image
        return inputs
@unittest.skip(reason="will not be supported due to dim-fusion")
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Batching is not supported yet")
|
||||
def test_num_images_per_prompt(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Batching is not supported yet")
|
||||
def test_inference_batch_consistent(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Batching is not supported yet")
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||
|
||||
def test_bria_fibo_different_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components())
|
||||
pipe = pipe.to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt"] = {"edit_instruction": "a different prompt"}
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
assert max_diff > 1e-6
|
||||
|
||||

    def test_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (64, 64), (32, 64)]
        for height, width in height_width_pairs:
            expected_height = height
            expected_width = width

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            assert (output_height, output_width) == (expected_height, expected_width)

    def test_bria_fibo_edit_mask(self):
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        mask = Image.fromarray((np.ones((192, 336)) * 255).astype(np.uint8), mode="L")

        inputs.update({"mask": mask})
        output = pipe(**inputs).images[0]

        assert output.shape == (192, 336, 3)

    def test_bria_fibo_edit_mask_image_size_mismatch(self):
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        mask = Image.fromarray((np.ones((64, 64)) * 255).astype(np.uint8), mode="L")

        inputs.update({"mask": mask})
        with self.assertRaisesRegex(ValueError, "Mask and image must have the same size"):
            pipe(**inputs)

    def test_bria_fibo_edit_mask_no_image(self):
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        mask = Image.fromarray((np.ones((32, 32)) * 255).astype(np.uint8), mode="L")

        # Drop the image added by get_dummy_inputs so that only the mask is passed.
        inputs.pop("image", None)
        inputs.update({"mask": mask})

        with self.assertRaisesRegex(ValueError, "If mask is provided, image must also be provided"):
            pipe(**inputs)