mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-31 08:51:08 +08:00
Compare commits
6 Commits
dist-infer
...
integratio
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc4f9ab484 | ||
|
|
36159dd2a6 | ||
|
|
01e521a8ce | ||
|
|
d242d02e5b | ||
|
|
51fcdf88aa | ||
|
|
35bebc78db |
@@ -39,6 +39,7 @@ Flux comes in the following variants:
|
||||
| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
|
||||
| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
|
||||
| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
|
||||
| Kontext | [`black-forest-labs/FLUX.1-Kontext`](https://huggingface.co/black-forest-labs/FLUX.1-Kontext) |
|
||||
|
||||
All checkpoints have different usage which we detail below.
|
||||
|
||||
@@ -273,6 +274,31 @@ images = pipe(
|
||||
images[0].save("flux-redux.png")
|
||||
```
|
||||
|
||||
### Kontext
|
||||
|
||||
Flux Kontext is a model that allows in-context control of the image generation process, allowing for editing, refinement, relighting, style transfer, character customization, and more.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxKontextPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = FluxKontextPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-kontext", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image("inputs/yarn-art-pikachu.png").convert("RGB")
|
||||
prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
|
||||
image = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
guidance_scale=2.5,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images[0]
|
||||
image.save("flux-kontext.png")
|
||||
```
|
||||
|
||||
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
|
||||
|
||||
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
|
||||
|
||||
@@ -375,6 +375,7 @@ else:
|
||||
"FluxFillPipeline",
|
||||
"FluxImg2ImgPipeline",
|
||||
"FluxInpaintPipeline",
|
||||
"FluxKontextPipeline",
|
||||
"FluxPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"HiDreamImagePipeline",
|
||||
@@ -960,6 +961,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxFillPipeline,
|
||||
FluxImg2ImgPipeline,
|
||||
FluxInpaintPipeline,
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
HiDreamImagePipeline,
|
||||
|
||||
@@ -140,6 +140,7 @@ else:
|
||||
"FluxFillPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"ReduxImageEncoder",
|
||||
"FluxKontextPipeline",
|
||||
]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
@@ -597,6 +598,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxFillPipeline,
|
||||
FluxImg2ImgPipeline,
|
||||
FluxInpaintPipeline,
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
ReduxImageEncoder,
|
||||
|
||||
@@ -33,6 +33,7 @@ else:
|
||||
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
|
||||
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
|
||||
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
|
||||
_import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"]
|
||||
_import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -52,6 +53,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_flux_fill import FluxFillPipeline
|
||||
from .pipeline_flux_img2img import FluxImg2ImgPipeline
|
||||
from .pipeline_flux_inpaint import FluxInpaintPipeline
|
||||
from .pipeline_flux_kontext import FluxKontextPipeline
|
||||
from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
1114
src/diffusers/pipelines/flux/pipeline_flux_kontext.py
Normal file
1114
src/diffusers/pipelines/flux/pipeline_flux_kontext.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -632,6 +632,21 @@ class FluxInpaintPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class FluxKontextPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class FluxPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
177
tests/pipelines/flux/test_pipeline_flux_kontext.py
Normal file
177
tests/pipelines/flux/test_pipeline_flux_kontext.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FasterCacheConfig,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
FluxKontextPipeline,
|
||||
FluxTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
class FluxKontextPipelineFastTests(
|
||||
unittest.TestCase,
|
||||
PipelineTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
):
|
||||
pipeline_class = FluxKontextPipeline
|
||||
params = frozenset(
|
||||
["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
|
||||
)
|
||||
batch_params = frozenset(["image", "prompt"])
|
||||
|
||||
# there is no xformers processor for Flux
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
faster_cache_config = FasterCacheConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(-1, 901),
|
||||
unconditional_batch_skip_range=2,
|
||||
attention_weight_callback=lambda _: 0.5,
|
||||
is_guidance_distilled=True,
|
||||
)
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = FluxTransformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=2,
|
||||
joint_attention_dim=32,
|
||||
pooled_projection_dim=32,
|
||||
axes_dims_rope=[4, 4, 8],
|
||||
)
|
||||
clip_text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
hidden_act="gelu",
|
||||
projection_dim=32,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = CLIPTextModel(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=1,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
shift_factor=0.0609,
|
||||
scaling_factor=1.5035,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer": tokenizer,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"image_encoder": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
image = PIL.Image.new("RGB", (32, 32), 0)
|
||||
inputs = {
|
||||
"image": image,
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_area": 8 * 8,
|
||||
"max_sequence_length": 48,
|
||||
"output_type": "np",
|
||||
"_auto_resize": False,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_flux_different_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt_2"] = "a different prompt"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
# For some reasons, they don't show large differences
|
||||
assert max_diff > 1e-6
|
||||
|
||||
def test_flux_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
height_width_pairs = [(32, 32), (72, 57)]
|
||||
for height, width in height_width_pairs:
|
||||
expected_height = height - height % (pipe.vae_scale_factor * 2)
|
||||
expected_width = width - width % (pipe.vae_scale_factor * 2)
|
||||
|
||||
inputs.update({"height": height, "width": width, "max_area": height * width})
|
||||
image = pipe(**inputs).images[0]
|
||||
output_height, output_width, _ = image.shape
|
||||
assert (output_height, output_width) == (expected_height, expected_width)
|
||||
|
||||
def test_flux_true_cfg(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs.pop("generator")
|
||||
|
||||
no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
|
||||
inputs["negative_prompt"] = "bad quality"
|
||||
inputs["true_cfg_scale"] = 2.0
|
||||
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
|
||||
assert not np.allclose(no_true_cfg_out, true_cfg_out)
|
||||
Reference in New Issue
Block a user