mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-07 13:04:15 +08:00

Compare commits: remove-tor...flux-singl (5 commits)

- b40812b18f
- 27910ee691
- 15924bc73b
- a64e8a333a
- b226d67d1d
@@ -77,6 +77,59 @@ out = pipe(
out.save("image.png")
```

+## Single File Loading for the `FluxTransformer2DModel`
+
+The `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
+
+<Tip>
+
+`FP8` inference can be brittle depending on the GPU type, CUDA version, and `torch` version that you are using. It is recommended that you use the `optimum-quanto` library in order to run FP8 inference on your machine.
+
+</Tip>
+
+The following example demonstrates how to run Flux with less than 16GB of VRAM.
+
+First install `optimum-quanto`
+
+```shell
+pip install optimum-quanto
+```
+
+Then run the following example
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel, FluxPipeline
+from transformers import T5EncoderModel, CLIPTextModel
+from optimum.quanto import freeze, qfloat8, quantize
+
+bfl_repo = "black-forest-labs/FLUX.1-dev"
+dtype = torch.bfloat16
+
+transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
+quantize(transformer, weights=qfloat8)
+freeze(transformer)
+
+text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
+quantize(text_encoder_2, weights=qfloat8)
+freeze(text_encoder_2)
+
+pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
+pipe.transformer = transformer
+pipe.text_encoder_2 = text_encoder_2
+
+pipe.enable_model_cpu_offload()
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(
+    prompt,
+    guidance_scale=3.5,
+    output_type="pil",
+    num_inference_steps=20,
+    generator=torch.Generator("cpu").manual_seed(0)
+).images[0]
+
+image.save("flux-fp8-dev.png")
+```
+
## FluxPipeline

[[autodoc]] FluxPipeline
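The documentation added above focuses on the quantized FP8 workflow. For orientation, here is a minimal sketch of the plain single-file path that the new section introduces; the local checkpoint path is hypothetical, and everything else uses the same classes shown in the example above.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

# Hypothetical local copy of a Flux transformer checkpoint in the original
# Black Forest Labs key layout (e.g. downloaded flux1-dev weights).
ckpt_path = "./flux1-dev.safetensors"

# from_single_file converts the original-format keys into the diffusers layout.
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)

# Swap the converted transformer into the standard pipeline.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

image = pipe("A cat holding a sign that says hello world", num_inference_steps=20).images[0]
image.save("flux-single-file.png")
```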
@@ -24,6 +24,7 @@ from .single_file_utils import (
    SingleFileComponentError,
    convert_animatediff_checkpoint_to_diffusers,
    convert_controlnet_checkpoint,
+    convert_flux_transformer_checkpoint_to_diffusers,
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
    convert_sd3_transformer_checkpoint_to_diffusers,
@@ -74,6 +75,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
    "MotionAdapter": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
+    "FluxTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
}

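To see how such an entry is structured, here is an illustrative sketch of looking it up; the import path mirrors the module patched above and should be treated as an assumption about internal layout, not public API.

```python
# Illustrative only: SINGLE_FILE_LOADABLE_CLASSES is an internal table, and this
# import path is an assumption based on the module this hunk patches.
from diffusers.loaders.single_file_model import SINGLE_FILE_LOADABLE_CLASSES

entry = SINGLE_FILE_LOADABLE_CLASSES["FluxTransformer2DModel"]

# The mapping function translates original-format keys to diffusers keys, and
# "default_subfolder" tells the loader where the diffusers config lives in the hub repo.
print(entry["checkpoint_mapping_fn"].__name__)  # convert_flux_transformer_checkpoint_to_diffusers
print(entry.get("default_subfolder"))           # transformer
```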
@@ -77,6 +77,7 @@ CHECKPOINT_KEY_NAMES = {
    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
    "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
    "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
+    "flux": "double_blocks.0.img_attn.norm.key_norm.scale",
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -110,6 +111,8 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
    "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
    "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
+    "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
+    "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
}

# Use to configure model sample size when original config is provided
@@ -503,6 +506,11 @@ def infer_diffusers_model_type(checkpoint):
        else:
            model_type = "animatediff_v3"

+    elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint:
+        if "guidance_in.in_layer.bias" in checkpoint:
+            model_type = "flux-dev"
+        else:
+            model_type = "flux-schnell"
    else:
        model_type = "v1"

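The detection above can be reproduced standalone with `safetensors`. A sketch follows; the file path is hypothetical and the keys are the ones registered in `CHECKPOINT_KEY_NAMES` in the previous hunk.

```python
from safetensors.torch import load_file

# Hypothetical local checkpoint in the original Black Forest Labs layout.
state_dict = load_file("./flux1-dev.safetensors")

# Key registered as CHECKPOINT_KEY_NAMES["flux"] above.
FLUX_KEY = "double_blocks.0.img_attn.norm.key_norm.scale"

if FLUX_KEY in state_dict:
    # FLUX.1-dev is guidance-distilled and carries guidance embedder weights;
    # FLUX.1-schnell does not, which is what the branch above relies on.
    model_type = "flux-dev" if "guidance_in.in_layer.bias" in state_dict else "flux-schnell"
else:
    model_type = "v1"  # fall-through default, as in infer_diffusers_model_type

print(model_type)
```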
@@ -1859,3 +1867,195 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
            ] = v

    return converted_state_dict
+
+
+def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
+    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
+    mlp_ratio = 4.0
+    inner_dim = 3072
+
+    # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+    # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    ## time_text_embed.timestep_embedder <- time_in
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "time_in.in_layer.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "time_in.out_layer.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
+
+    ## time_text_embed.text_embedder <- vector_in
+    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
+    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+        "vector_in.out_layer.weight"
+    )
+    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
+
+    # guidance
+    has_guidance = any("guidance" in k for k in checkpoint)
+    if has_guidance:
+        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+            "guidance_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+            "guidance_in.in_layer.bias"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+            "guidance_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+            "guidance_in.out_layer.bias"
+        )
+
+    # context_embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
+
+    # x_embedder
+    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
+    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        # norms.
+        ## norm1
+        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mod.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mod.lin.bias"
+        )
+        ## norm1_context
+        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mod.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mod.lin.bias"
+        )
+        # Q, K, V
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+        # ff img_mlp
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
+        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
+        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.bias"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.bias"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.bias"
+        )
+
+    # single transfomer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+        # norm.linear <- single_blocks.0.modulation.lin
+        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.modulation.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+            f"single_blocks.{i}.modulation.lin.bias"
+        )
+        # Q, K, V, mlp
+        mlp_hidden_dim = int(inner_dim * mlp_ratio)
+        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+        q_bias, k_bias, v_bias, mlp_bias = torch.split(
+            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
+        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
+
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+    )
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+    )
+
+    return converted_state_dict
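Normally `from_single_file` invokes this converter internally, but it can also be exercised directly. The sketch below assumes the function is importable from `diffusers.loaders.single_file_utils` (the module this hunk patches); the checkpoint path is hypothetical.

```python
from safetensors.torch import load_file

# Assumed import path, mirroring the module patched above.
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers

# Hypothetical local copy of the original-format transformer weights.
original = load_file("./flux1-dev.safetensors")

# Renames the original "double_blocks.*" / "single_blocks.*" keys into the diffusers
# "transformer_blocks.*" / "single_transformer_blocks.*" layout and splits the fused
# qkv projections. Note that the input dict is consumed via pop().
converted = convert_flux_transformer_checkpoint_to_diffusers(original)

print(sorted(k for k in converted if k.startswith("transformer_blocks.0."))[:5])
```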
@@ -20,7 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
from ...models.modeling_utils import ModelMixin
@@ -227,7 +227,7 @@ class FluxTransformerBlock(nn.Module):
        return encoder_hidden_states, hidden_states


-class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    The Transformer model introduced in Flux.

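Adding `FromOriginalModelMixin` to the base classes is what exposes the single-file loader on the model class; a one-line check, as a sketch:

```python
from diffusers import FluxTransformer2DModel

# from_single_file is provided by FromOriginalModelMixin, added to the bases above.
assert hasattr(FluxTransformer2DModel, "from_single_file")
```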