Compare commits


6 Commits

Author      SHA1        Message                                                     Date
Sayak Paul  fe4c0be8a6  Merge branch 'main' into device-map-direct                  2026-01-19 10:26:24 +05:30
Sayak Paul  b28d6d45fa  Merge branch 'main' into device-map-direct                  2026-01-13 10:35:08 +05:30
Sayak Paul  3b334de68a  Merge branch 'main' into device-map-direct                  2026-01-08 12:23:39 +05:30
Sayak Paul  c61e455ce7  Merge branch 'main' into device-map-direct                  2025-12-23 13:16:10 +05:30
Sayak Paul  6f5eb0a933  Merge branch 'main' into device-map-direct                  2025-12-11 14:47:09 +08:00
sayakpaul   83ec2fb793  support device type device_maps to work with offloading.   2025-12-09 11:10:41 +05:30
24 changed files with 67 additions and 2142 deletions

View File

@@ -496,8 +496,6 @@
title: Bria 3.2
- local: api/pipelines/bria_fibo
title: Bria Fibo
- local: api/pipelines/bria_fibo_edit
title: Bria Fibo Edit
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3

View File

@@ -1,33 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Bria Fibo Edit
Fibo Edit is an 8B parameter image-to-image model that introduces a new paradigm of structured control, operating on JSON inputs paired with source images to enable deterministic and repeatable editing workflows.
Featuring native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments.
Its lightweight architecture is designed for deep customization, empowering researchers to build specialized "Edit" models for domain-specific tasks while delivering top-tier aesthetic quality.
## Usage
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Edit Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate._
Use the command below to log in:
```bash
hf auth login
```
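_The snippet below is a hypothetical loading sketch rather than part of the official documentation: the checkpoint id matches the gated repository above, while the call arguments (a JSON `prompt`, a source `image`, and an optional `mask`) are inferred from the pipeline's test inputs and may differ from the released API._
```python
import torch
from PIL import Image

from diffusers import BriaFiboEditPipeline

# Load the gated checkpoint (requires `hf auth login` and accepting the gate first).
pipe = BriaFiboEditPipeline.from_pretrained("briaai/Fibo-Edit", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Fibo Edit consumes structured JSON prompts paired with a source image.
prompt = '{"edit_instruction": "replace the background with a starry night sky"}'
image = Image.open("source.png").convert("RGB")

result = pipe(
    prompt=prompt,
    image=image,
    num_inference_steps=30,
    guidance_scale=5.0,
).images[0]
result.save("fibo_edit_output.png")
```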
## BriaFiboEditPipeline
[[autodoc]] BriaFiboEditPipeline
- all
- __call__

View File

@@ -413,9 +413,6 @@ else:
_import_structure["modular_pipelines"].extend(
[
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
@@ -460,7 +457,6 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaFiboEditPipeline",
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
@@ -1149,9 +1145,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .modular_pipelines import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
FluxAutoBlocks,
FluxKontextAutoBlocks,
@@ -1192,7 +1185,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
BriaFiboEditPipeline,
BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,

View File

@@ -54,10 +54,7 @@ else:
]
_import_structure["flux2"] = [
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
@@ -84,13 +81,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .flux2 import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
)
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,

View File

@@ -43,7 +43,7 @@ else:
"Flux2ProcessImagesInputStep",
"Flux2TextInputStep",
]
_import_structure["modular_blocks_flux2"] = [
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"REMOTE_AUTO_BLOCKS",
@@ -54,8 +54,7 @@ else:
"Flux2BeforeDenoiseStep",
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -86,7 +85,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
from .modular_blocks_flux2 import (
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE_CONDITIONED_BLOCKS,
@@ -97,11 +96,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2BeforeDenoiseStep,
Flux2VaeEncoderSequentialStep,
)
from .modular_blocks_flux2_klein import (
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
)
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
from .modular_pipeline import Flux2ModularPipeline
else:
import sys

View File

@@ -353,7 +353,7 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
InputParam(name="negative_prompt_embeds", required=False),
InputParam(name="latent_ids"),
]
@property
@@ -366,10 +366,10 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
OutputParam(
name="negative_txt_ids",
name="latent_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.",
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
),
]
@@ -399,11 +399,6 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
block_state.txt_ids = block_state.txt_ids.to(device)
block_state.negative_txt_ids = None
if block_state.negative_prompt_embeds is not None:
block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds)
block_state.negative_txt_ids = block_state.negative_txt_ids.to(device)
self.set_block_state(state, block_state)
return components, state

View File

@@ -12,13 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Tuple
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging
@@ -28,8 +25,8 @@ from ..modular_pipeline import (
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
if is_torch_xla_available():
@@ -137,241 +134,6 @@ class Flux2LoopDenoiser(ModularPipelineBlocks):
return components, block_state
# same as Flux2LoopDenoiser but guidance=None
class Flux2KleinLoopDenoiser(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("transformer", Flux2Transformer2DModel)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents for Flux2. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to denoise. Shape: (B, seq_len, C)",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
),
InputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Text embeddings from Qwen3",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for text tokens (T, H, W, L)",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
]
@torch.no_grad()
def __call__(
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents.to(components.transformer.dtype)
img_ids = block_state.latent_ids
image_latents = getattr(block_state, "image_latents", None)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
image_latent_ids = block_state.image_latent_ids
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=block_state.prompt_embeds,
txt_ids=block_state.txt_ids,
img_ids=img_ids,
joint_attention_kwargs=block_state.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
block_state.noise_pred = noise_pred
return components, block_state
# support CFG for Flux2-Klein base model
class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("transformer", Flux2Transformer2DModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=False),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents for Flux2. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to denoise. Shape: (B, seq_len, C)",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
),
InputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Text embeddings from Qwen3",
),
InputParam(
"negative_prompt_embeds",
required=False,
type_hint=torch.Tensor,
description="Negative text embeddings from Qwen3",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for text tokens (T, H, W, L)",
),
InputParam(
"negative_txt_ids",
required=False,
type_hint=torch.Tensor,
description="4D position IDs for negative text tokens (T, H, W, L)",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
]
@torch.no_grad()
def __call__(
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents.to(components.transformer.dtype)
img_ids = block_state.latent_ids
image_latents = getattr(block_state, "image_latents", None)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
image_latent_ids = block_state.image_latent_ids
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
guider_inputs = {
"encoder_hidden_states": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"txt_ids": (
getattr(block_state, "txt_ids", None),
getattr(block_state, "negative_txt_ids", None),
),
}
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
additional_cond_kwargs = {}
for field_name, field_value in block_state.denoiser_input_fields.items():
if field_name in transformer_args and field_name not in guider_inputs:
additional_cond_kwargs[field_name] = field_value
block_state.additional_cond_kwargs.update(additional_cond_kwargs)
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
cond_kwargs.update(additional_cond_kwargs)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
img_ids=img_ids,
joint_attention_kwargs=block_state.joint_attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)]
components.guider.cleanup_models(components.transformer)
# perform guidance
block_state.noise_pred = components.guider(guider_state)[0]
return components, block_state
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux2"
@@ -458,8 +220,6 @@ class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
block_state.additional_cond_kwargs = {}
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
@@ -490,35 +250,3 @@ class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)
class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents for Flux2. \n"
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `Flux2KleinLoopDenoiser`\n"
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)
class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents for Flux2. \n"
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `Flux2KleinBaseLoopDenoiser`\n"
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)

View File

@@ -15,13 +15,13 @@
from typing import List, Optional, Tuple, Union
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -79,8 +79,10 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
InputParam("joint_attention_kwargs"),
]
@property
@@ -97,7 +99,14 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
@@ -156,6 +165,10 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
@@ -192,6 +205,7 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
]
@property
@@ -208,8 +222,15 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
@@ -223,6 +244,10 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
@@ -245,153 +270,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
return components, state
class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen3ForCausalLM),
ComponentSpec("tokenizer", Qwen2TokenizerFast),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=False),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from qwen3 used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
description="Negative text embeddings from qwen3 used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
hidden_states_layers: List[int] = (9, 18, 27),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@torch.no_grad()
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)
device = components._execution_device
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
if components.requires_unconditional_embeds:
negative_prompt = ""
block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=negative_prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
else:
block_state.negative_prompt_embeds = None
self.set_block_state(state, block_state)
return components, state
class Flux2VaeEncoderStep(ModularPipelineBlocks):
model_name = "flux2"

View File

@@ -47,14 +47,7 @@ class Flux2TextInputStep(ModularPipelineBlocks):
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
InputParam(
"negative_prompt_embeds",
required=False,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
),
]
@@ -77,12 +70,6 @@ class Flux2TextInputStep(ModularPipelineBlocks):
kwargs_type="denoiser_input_fields",
description="Text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Negative text embeddings used to guide the image generation",
),
]
@torch.no_grad()
@@ -98,15 +85,6 @@ class Flux2TextInputStep(ModularPipelineBlocks):
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
1, block_state.num_images_per_prompt, 1
)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -1,171 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
from .encoders import (
Flux2KleinTextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Flux2KleinVaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
]
)
class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2KleinVaeEncoderBlocks.values()
block_names = Flux2KleinVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [Flux2KleinVaeEncoderSequentialStep]
block_names = ["img_conditioning"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"VAE encoder step that encodes the image inputs into their latent representations.\n"
"This is an auto pipeline block that works for image conditioning tasks.\n"
" - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n"
" - If `image` is not provided, step will be skipped."
)
Flux2KleinCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2KleinDenoiseStep()),
]
)
class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinCoreDenoiseBlocks.values()
block_names = Flux2KleinCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)."
return (
"Core denoise step that performs the denoising process for Flux2-Klein.\n"
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
)
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2KleinBaseDenoiseStep()),
]
)
class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
return (
"Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
)
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self):
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinBaseCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self):
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)

View File

@@ -13,8 +13,6 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -57,56 +55,3 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
"""
A ModularPipeline for Flux2-Klein.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2KleinBaseAutoBlocks"
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinAutoBlocks"
else:
return "Flux2KleinBaseAutoBlocks"
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
@property
def requires_unconditional_embeds(self):
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
return False
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds

View File

@@ -59,7 +59,6 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("flux2-klein", "Flux2KleinModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),

View File

@@ -129,7 +129,7 @@ else:
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
@@ -597,7 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .bria import BriaPipeline
from .bria_fibo import BriaFiboEditPipeline, BriaFiboPipeline
from .bria_fibo import BriaFiboPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline
from .chronoedit import ChronoEditPipeline
from .cogvideo import (

View File

@@ -23,8 +23,6 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
_import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -35,7 +33,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_bria_fibo import BriaFiboPipeline
from .pipeline_bria_fibo_edit import BriaFiboEditPipeline
else:
import sys

File diff suppressed because it is too large.

View File

@@ -84,6 +84,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
>>> from diffusers.utils import load_image
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
>>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
>>> controlnet = ControlNetModel.from_pretrained(

View File

@@ -53,6 +53,7 @@ EXAMPLE_DOC_STRING = """
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> from diffusers import HiDreamImagePipeline
>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
... "meta-llama/Meta-Llama-3.1-8B-Instruct",

View File

@@ -85,6 +85,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL
>>> from diffusers.utils import load_image
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
>>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
>>> controlnet = ControlNetModel.from_pretrained(

View File

@@ -111,7 +111,7 @@ LIBRARIES = []
for library in LOADABLE_CLASSES:
LIBRARIES.append(library)
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"]
logger = logging.get_logger(__name__)
@@ -467,8 +467,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -1187,7 +1186,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1311,7 +1310,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
self.remove_all_hooks()
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
@@ -2200,6 +2199,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
return True
return False
def _is_pipeline_device_mapped(self):
# We support passing a bare device type such as `device_map="cuda"`. This is helpful in case
# users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable
# in limited VRAM environments because quantized models often initialize directly on the accelerator.
device_map = self.hf_device_map
is_device_type_map = False
if isinstance(device_map, str):
try:
torch.device(device_map)
is_device_type_map = True
except RuntimeError:
pass
return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
class StableDiffusionMixin:
r"""

View File

@@ -459,6 +459,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
>>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
>>> import torch
>>> pipeline = StableDiffusionPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
... )

View File

@@ -17,51 +17,6 @@ class Flux2AutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2ModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -632,21 +587,6 @@ class AuraFlowPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class BriaFiboEditPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class BriaFiboPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -1,192 +0,0 @@
# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM
from diffusers import (
AutoencoderKLWan,
BriaFiboEditPipeline,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from tests.pipelines.test_pipelines_common import PipelineTesterMixin
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
enable_full_determinism()
class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = BriaFiboEditPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
test_layerwise_casting = False
test_group_offloading = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = BriaFiboTransformer2DModel(
patch_size=1,
in_channels=16,
num_layers=1,
num_single_layers=1,
attention_head_dim=8,
num_attention_heads=2,
joint_attention_dim=64,
text_encoder_dim=32,
pooled_projection_dim=None,
axes_dims_rope=[0, 4, 4],
)
vae = AutoencoderKLWan(
base_dim=80,
decoder_base_dim=128,
dim_mult=[1, 2, 4, 4],
dropout=0.0,
in_channels=12,
latents_mean=[0.0] * 16,
latents_std=[1.0] * 16,
is_residual=True,
num_res_blocks=2,
out_channels=12,
patch_size=2,
scale_factor_spatial=16,
scale_factor_temporal=4,
temperal_downsample=[False, True, True],
z_dim=16,
)
scheduler = FlowMatchEulerDiscreteScheduler()
text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": '{"text": "A painting of a squirrel eating a burger","edit_instruction": "A painting of a squirrel eating a burger"}',
"negative_prompt": "bad, ugly",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 192,
"width": 336,
"output_type": "np",
}
image = Image.new("RGB", (336, 192), (255, 255, 255))
inputs["image"] = image
return inputs
@unittest.skip(reason="will not be supported due to dim-fusion")
def test_encode_prompt_works_in_isolation(self):
pass
@unittest.skip(reason="Batching is not supported yet")
def test_num_images_per_prompt(self):
pass
@unittest.skip(reason="Batching is not supported yet")
def test_inference_batch_consistent(self):
pass
@unittest.skip(reason="Batching is not supported yet")
def test_inference_batch_single_identical(self):
pass
def test_bria_fibo_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = {"edit_instruction": "a different prompt"}
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
assert max_diff > 1e-6
def test_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (64, 64), (32, 64)]
for height, width in height_width_pairs:
expected_height = height
expected_width = width
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
def test_bria_fibo_edit_mask(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
mask = Image.fromarray((np.ones((192, 336)) * 255).astype(np.uint8), mode="L")
inputs.update({"mask": mask})
output = pipe(**inputs).images[0]
assert output.shape == (192, 336, 3)
def test_bria_fibo_edit_mask_image_size_mismatch(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
mask = Image.fromarray((np.ones((64, 64)) * 255).astype(np.uint8), mode="L")
inputs.update({"mask": mask})
with self.assertRaisesRegex(ValueError, "Mask and image must have the same size"):
pipe(**inputs)
def test_bria_fibo_edit_mask_no_image(self):
pipe = self.pipeline_class(**self.get_dummy_components())
pipe = pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
mask = Image.fromarray((np.ones((32, 32)) * 255).astype(np.uint8), mode="L")
# Remove image from inputs if it's there (it shouldn't be by default from get_dummy_inputs)
inputs.pop("image", None)
inputs.update({"mask": mask})
with self.assertRaisesRegex(ValueError, "If mask is provided, image must also be provided"):
pipe(**inputs)