Compare commits

...

6 Commits

Author SHA1 Message Date
Aryan
cc4f9ab484 update docs 2025-05-29 04:37:45 +02:00
Aryan
36159dd2a6 add tests 2025-05-29 04:34:31 +02:00
Aryan
01e521a8ce Merge branch 'main' into integrations/flux-kontext 2025-05-29 04:11:51 +02:00
Aryan
d242d02e5b add example 2025-05-28 05:46:40 +02:00
Aryan
51fcdf88aa make fix-copies 2025-05-28 05:19:41 +02:00
Aryan
35bebc78db support flux kontext 2025-05-28 05:19:29 +02:00
7 changed files with 1338 additions and 0 deletions

View File

@@ -39,6 +39,7 @@ Flux comes in the following variants:
| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
| Kontext | [`black-forest-labs/FLUX.1-Kontext`](https://huggingface.co/black-forest-labs/FLUX.1-Kontext) |
All checkpoints have different usage which we detail below.
@@ -273,6 +274,31 @@ images = pipe(
images[0].save("flux-redux.png")
```
### Kontext
Flux Kontext is a model that allows in-context control of the image generation process, allowing for editing, refinement, relighting, style transfer, character customization, and more.
```python
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
pipe = FluxKontextPipeline.from_pretrained(
"black-forest-labs/FLUX.1-kontext", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image = load_image("inputs/yarn-art-pikachu.png").convert("RGB")
prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
image = pipe(
image=image,
prompt=prompt,
guidance_scale=2.5,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("flux-kontext.png")
```
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).

View File

@@ -375,6 +375,7 @@ else:
"FluxFillPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
"HiDreamImagePipeline",
@@ -960,6 +961,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
HiDreamImagePipeline,

View File

@@ -140,6 +140,7 @@ else:
"FluxFillPipeline",
"FluxPriorReduxPipeline",
"ReduxImageEncoder",
"FluxKontextPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -597,6 +598,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
ReduxImageEncoder,

View File

@@ -33,6 +33,7 @@ else:
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
_import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"]
_import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -52,6 +53,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_flux_fill import FluxFillPipeline
from .pipeline_flux_img2img import FluxImg2ImgPipeline
from .pipeline_flux_inpaint import FluxInpaintPipeline
from .pipeline_flux_kontext import FluxKontextPipeline
from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
else:
import sys

File diff suppressed because it is too large Load Diff

View File

@@ -632,6 +632,21 @@ class FluxInpaintPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class FluxKontextPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,177 @@
import unittest
import numpy as np
import PIL.Image
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
FasterCacheConfig,
FlowMatchEulerDiscreteScheduler,
FluxKontextPipeline,
FluxTransformer2DModel,
)
from diffusers.utils.testing_utils import torch_device
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
)
class FluxKontextPipelineFastTests(
unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
):
pipeline_class = FluxKontextPipeline
params = frozenset(
["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
)
batch_params = frozenset(["image", "prompt"])
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
faster_cache_config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
is_guidance_distilled=True,
)
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=4,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
image = PIL.Image.new("RGB", (32, 32), 0)
inputs = {
"image": image,
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 8,
"width": 8,
"max_area": 8 * 8,
"max_sequence_length": 48,
"output_type": "np",
"_auto_resize": False,
}
return inputs
def test_flux_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
# For some reasons, they don't show large differences
assert max_diff > 1e-6
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width, "max_area": height * width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
def test_flux_true_cfg(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
inputs.pop("generator")
no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
inputs["negative_prompt"] = "bad quality"
inputs["true_cfg_scale"] = 2.0
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
assert not np.allclose(no_true_cfg_out, true_cfg_out)