Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-22 19:45:47 +08:00)

Compare commits: `modular-kl` ... — 14 commits by devanshi00
| Author | SHA1 | Date |
|---|---|---|
| devanshi00 | c967e88afc | |
| devanshi00 | 153b310f2d | |
| devanshi00 | bef2ba43e5 | |
| devanshi00 | 58791aaa15 | |
| devanshi00 | 40f46d9ec7 | |
| devanshi00 | 0d9b1b70e1 | |
| devanshi00 | 69035fdd76 | |
| devanshi00 | f70e9cf071 | |
| devanshi00 | 405b40006c | |
| devanshi00 | a91e3040ac | |
| devanshi00 | 3bc3fdb035 | |
| devanshi00 | 8cc38a75d3 | |
| devanshi00 | e5bb10cfe1 | |
| devanshi00 | ec541906c5 | |
```
@@ -413,9 +413,6 @@ else:
    _import_structure["modular_pipelines"].extend(
        [
            "Flux2AutoBlocks",
            "Flux2KleinAutoBlocks",
            "Flux2KleinBaseAutoBlocks",
            "Flux2KleinModularPipeline",
            "Flux2ModularPipeline",
            "FluxAutoBlocks",
            "FluxKontextAutoBlocks",
@@ -1149,9 +1146,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
    from .modular_pipelines import (
        Flux2AutoBlocks,
        Flux2KleinAutoBlocks,
        Flux2KleinBaseAutoBlocks,
        Flux2KleinModularPipeline,
        Flux2ModularPipeline,
        FluxAutoBlocks,
        FluxKontextAutoBlocks,
```
```
@@ -675,6 +675,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        variant: Optional[str] = None,
        max_shard_size: Union[int, str] = "10GB",
        push_to_hub: bool = False,
        use_flashpack: bool = False,
        **kwargs,
    ):
        """
@@ -707,6 +708,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            use_flashpack (`bool`, *optional*, defaults to `False`):
                Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a
                binary format that allows for faster loading. Requires the `flashpack` library to be installed.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
@@ -727,12 +731,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                    " the logger on the traceback to understand the reason why the quantized model is not serializable."
                )

        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
        weights_name = _add_variant(weights_name, variant)
        weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
            ".safetensors", "{suffix}.safetensors"
        )

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
@@ -746,67 +744,80 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        # Only save the model itself if we are using distributed training
        model_to_save = self

        # Attach architecture to the config
        # Save the config
        if is_main_process:
            model_to_save.save_config(save_directory)

        # Save the model
        state_dict = model_to_save.state_dict()
        if use_flashpack:
            if not is_main_process:
                return

        # Save the model
        state_dict_split = split_torch_state_dict_into_shards(
            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
        )
            from ..utils.flashpack_utils import save_flashpack

        # Clean the folder from a previous save
        if is_main_process:
            for filename in os.listdir(save_directory):
                if filename in state_dict_split.filename_to_tensors.keys():
                    continue
                full_filename = os.path.join(save_directory, filename)
                if not os.path.isfile(full_filename):
                    continue
                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
                weights_without_ext = weights_without_ext.replace("{suffix}", "")
                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
                if (
                    filename.startswith(weights_without_ext)
                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
                ):
                    os.remove(full_filename)

        for filename, tensors in state_dict_split.filename_to_tensors.items():
            shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
            filepath = os.path.join(save_directory, filename)
            if safe_serialization:
                # At some point we will need to deal better with save_function (used for TPU and other distributed
                # joyfulness), but for now this is enough.
                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
            else:
                torch.save(shard, filepath)

        if state_dict_split.is_sharded:
            index = {
                "metadata": state_dict_split.metadata,
                "weight_map": state_dict_split.tensor_to_filename,
            }
            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
            # Save the index as well
            with open(save_index_file, "w", encoding="utf-8") as f:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                f.write(content)
            logger.info(
                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
                f"split into {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
                f"index located at {save_index_file}."
            )
            save_flashpack(model_to_save, save_directory, variant=variant)
        else:
            path_to_weights = os.path.join(save_directory, weights_name)
            logger.info(f"Model weights saved in {path_to_weights}")
            weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
            weights_name = _add_variant(weights_name, variant)
            weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
                ".safetensors", "{suffix}.safetensors"
            )

            state_dict = model_to_save.state_dict()
            state_dict_split = split_torch_state_dict_into_shards(
                state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
            )

            # Clean the folder from a previous save
            if is_main_process:
                for filename in os.listdir(save_directory):
                    if filename in state_dict_split.filename_to_tensors.keys():
                        continue
                    full_filename = os.path.join(save_directory, filename)
                    if not os.path.isfile(full_filename):
                        continue
                    weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
                    weights_without_ext = weights_without_ext.replace("{suffix}", "")
                    filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
                    # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
                    if (
                        filename.startswith(weights_without_ext)
                        and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
                    ):
                        os.remove(full_filename)

            # Save each shard
            for filename, tensors in state_dict_split.filename_to_tensors.items():
                shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
                filepath = os.path.join(save_directory, filename)
                if safe_serialization:
                    # At some point we will need to deal better with save_function (used for TPU and other distributed
                    # joyfulness), but for now this is enough.
                    safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
                else:
                    torch.save(shard, filepath)

            # Save index file if sharded
            if state_dict_split.is_sharded:
                index = {
                    "metadata": state_dict_split.metadata,
                    "weight_map": state_dict_split.tensor_to_filename,
                }
                save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
                save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
                # Save the index as well
                with open(save_index_file, "w", encoding="utf-8") as f:
                    content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                    f.write(content)
                logger.info(
                    f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
                    f"split into {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
                    f"index located at {save_index_file}."
                )
            else:
                path_to_weights = os.path.join(save_directory, weights_name)
                logger.info(f"Model weights saved in {path_to_weights}")

        # Push to hub if requested (common to both paths)
        if push_to_hub:
            # Create a new empty model card and eventually tag it
            model_card = load_or_create_model_card(repo_id, token=token)
```
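The `use_flashpack` flag added to `save_pretrained` above delegates serialization to FlashPack instead of the sharded safetensors writer. A minimal usage sketch, assuming the `flashpack` library is installed; the checkpoint and output path are illustrative:

```python
from diffusers import AutoencoderKL

# Illustrative checkpoint; any diffusers model inheriting ModelMixin should work.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# Writes the config plus a FlashPack binary instead of sharded safetensors.
vae.save_pretrained("./vae-flashpack", use_flashpack=True)
```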
```
@@ -939,6 +950,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                weights. If set to `False`, `safetensors` weights are not loaded.
            use_flashpack (`bool`, *optional*, defaults to `False`):
                If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack)
                weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack`
                file cannot be used, loading automatically falls back to the standard loading path (for example, `safetensors`).
            disable_mmap ('bool', *optional*, defaults to 'False'):
                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -982,6 +997,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        use_flashpack = kwargs.pop("use_flashpack", False)
        quantization_config = kwargs.pop("quantization_config", None)
        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
        disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1199,7 +1215,31 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

            model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
        else:

            flashpack_file = None
            if use_flashpack:
                try:
                    flashpack_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=_add_variant("model.flashpack", variant),
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        token=token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                        commit_hash=commit_hash,
                        dduf_entries=dduf_entries,
                    )
                except EnvironmentError:
                    flashpack_file = None
                    logger.warning(
                        "`use_flashpack` was specified to be True but no flashpack file was found. Resorting to non-flashpack alternatives."
                    )

            if flashpack_file is None:
                # in the case it is sharded, we already have the index
                if is_sharded:
                    resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
@@ -1215,6 +1255,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                        dduf_entries=dduf_entries,
                    )
                elif use_safetensors:
                    logger.warning("Trying to load model weights with safetensors format.")
                    try:
                        resolved_model_file = _get_model_file(
                            pretrained_model_name_or_path,
@@ -1280,6 +1321,29 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)

        if flashpack_file is not None:
            from ..utils.flashpack_utils import load_flashpack

            # Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
            # the model with meta tensors. Since FlashPack cannot write into meta tensors, we
            # explicitly materialize parameters before loading to ensure correctness and parity
            # with the standard loading path.
            if any(p.device.type == "meta" for p in model.parameters()):
                model.to_empty(device="cpu")
            load_flashpack(model, flashpack_file)
            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
            model.eval()

            if output_loading_info:
                return model, {
                    "missing_keys": [],
                    "unexpected_keys": [],
                    "mismatched_keys": [],
                    "error_msgs": [],
                }

            return model

        state_dict = None
        if not is_sharded:
            # Time to load the checkpoint
@@ -1327,7 +1391,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            keep_in_fp32_modules=keep_in_fp32_modules,
            dduf_entries=dduf_entries,
            is_parallel_loading_enabled=is_parallel_loading_enabled,
            disable_mmap=disable_mmap,
        )
        loading_info = {
            "missing_keys": missing_keys,
@@ -1373,6 +1436,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        if output_loading_info:
            return model, loading_info

        logger.warning(f"Model {pretrained_model_name_or_path} loaded successfully")

        return model

    # Adapted from `transformers`.
```
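On the loading side, `from_pretrained(..., use_flashpack=True)` first probes for a `model.flashpack` file and falls back to the regular weights when none is found. A sketch continuing the illustrative directory from the saving example above:

```python
from diffusers import AutoencoderKL

# Loads ./vae-flashpack/model.flashpack if present; otherwise falls back
# to the standard (e.g. safetensors) loading path with a warning.
vae = AutoencoderKL.from_pretrained("./vae-flashpack", use_flashpack=True)
```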
```
@@ -54,10 +54,7 @@ else:
    ]
    _import_structure["flux2"] = [
        "Flux2AutoBlocks",
        "Flux2KleinAutoBlocks",
        "Flux2KleinBaseAutoBlocks",
        "Flux2ModularPipeline",
        "Flux2KleinModularPipeline",
    ]
    _import_structure["qwenimage"] = [
        "QwenImageAutoBlocks",
@@ -84,13 +81,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
    from .components_manager import ComponentsManager
    from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
    from .flux2 import (
        Flux2AutoBlocks,
        Flux2KleinAutoBlocks,
        Flux2KleinBaseAutoBlocks,
        Flux2KleinModularPipeline,
        Flux2ModularPipeline,
    )
    from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
    from .modular_pipeline import (
        AutoPipelineBlocks,
        BlockState,
```
```
@@ -43,7 +43,7 @@ else:
        "Flux2ProcessImagesInputStep",
        "Flux2TextInputStep",
    ]
    _import_structure["modular_blocks_flux2"] = [
    _import_structure["modular_blocks"] = [
        "ALL_BLOCKS",
        "AUTO_BLOCKS",
        "REMOTE_AUTO_BLOCKS",
@@ -51,11 +51,10 @@ else:
        "IMAGE_CONDITIONED_BLOCKS",
        "Flux2AutoBlocks",
        "Flux2AutoVaeEncoderStep",
        "Flux2CoreDenoiseStep",
        "Flux2BeforeDenoiseStep",
        "Flux2VaeEncoderSequentialStep",
    ]
    _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
    _import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
    _import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
@@ -86,7 +85,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        Flux2ProcessImagesInputStep,
        Flux2TextInputStep,
    )
    from .modular_blocks_flux2 import (
    from .modular_blocks import (
        ALL_BLOCKS,
        AUTO_BLOCKS,
        IMAGE_CONDITIONED_BLOCKS,
@@ -94,14 +93,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        TEXT2IMAGE_BLOCKS,
        Flux2AutoBlocks,
        Flux2AutoVaeEncoderStep,
        Flux2CoreDenoiseStep,
        Flux2BeforeDenoiseStep,
        Flux2VaeEncoderSequentialStep,
    )
    from .modular_blocks_flux2_klein import (
        Flux2KleinAutoBlocks,
        Flux2KleinBaseAutoBlocks,
    )
    from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
    from .modular_pipeline import Flux2ModularPipeline
else:
    import sys
```
```
@@ -129,9 +129,17 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
            InputParam("num_inference_steps", default=50),
            InputParam("timesteps"),
            InputParam("sigmas"),
            InputParam("guidance_scale", default=4.0),
            InputParam("latents", type_hint=torch.Tensor),
            InputParam("num_images_per_prompt", default=1),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam(
                "batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
            ),
        ]

    @property
@@ -143,12 +151,13 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
                type_hint=int,
                description="The number of denoising steps to perform at inference time",
            ),
            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        block_state.device = components._execution_device

        scheduler = components.scheduler

@@ -174,7 +183,7 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler,
            num_inference_steps,
            device,
            block_state.device,
            timesteps=timesteps,
            sigmas=sigmas,
            mu=mu,
@@ -182,6 +191,11 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
        block_state.timesteps = timesteps
        block_state.num_inference_steps = num_inference_steps

        batch_size = block_state.batch_size * block_state.num_images_per_prompt
        guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
        guidance = guidance.expand(batch_size)
        block_state.guidance = guidance

        components.scheduler.set_begin_index(0)

        self.set_block_state(state, block_state)
```
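The guidance preparation folded into `Flux2SetTimestepsStep` above fills a one-element tensor with `guidance_scale` and expands it to `batch_size * num_images_per_prompt`. A standalone sketch of that shape arithmetic (values are illustrative):

```python
import torch

# Illustrative values: 2 prompts, 2 images per prompt -> effective batch of 4.
batch_size, num_images_per_prompt, guidance_scale = 2, 2, 4.0

effective_batch = batch_size * num_images_per_prompt
guidance = torch.full([1], guidance_scale, dtype=torch.float32)  # shape: (1,)
guidance = guidance.expand(effective_batch)                      # shape: (4,), no copy

assert guidance.shape == (4,)
assert torch.all(guidance == 4.0)
```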
```
@@ -339,6 +353,7 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="prompt_embeds", required=True),
            InputParam(name="latent_ids"),
        ]

    @property
@@ -350,6 +365,12 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
                type_hint=torch.Tensor,
                description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
            ),
            OutputParam(
                name="latent_ids",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
            ),
        ]

    @staticmethod
@@ -382,72 +403,6 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
        return components, state


class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
    model_name = "flux2-klein"

    @property
    def description(self) -> str:
        return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="prompt_embeds", required=True),
            InputParam(name="negative_prompt_embeds", required=False),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                name="txt_ids",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
            ),
            OutputParam(
                name="negative_txt_ids",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.",
            ),
        ]

    @staticmethod
    def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
        """Prepare 4D position IDs for text tokens."""
        B, L, _ = x.shape
        out_ids = []

        for i in range(B):
            t = torch.arange(1) if t_coord is None else t_coord[i]
            h = torch.arange(1)
            w = torch.arange(1)
            seq_l = torch.arange(L)

            coords = torch.cartesian_prod(t, h, w, seq_l)
            out_ids.append(coords)

        return torch.stack(out_ids)

    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        prompt_embeds = block_state.prompt_embeds
        device = prompt_embeds.device

        block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
        block_state.txt_ids = block_state.txt_ids.to(device)

        block_state.negative_txt_ids = None
        if block_state.negative_prompt_embeds is not None:
            block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds)
            block_state.negative_txt_ids = block_state.negative_txt_ids.to(device)

        self.set_block_state(state, block_state)
        return components, state
```
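The `_prepare_text_ids` helper in the removed `Flux2KleinBaseRoPEInputsStep` builds one (T, H, W, L) coordinate per text token with `torch.cartesian_prod`; for text, only the sequence axis varies. A self-contained sketch with dummy embeddings (shapes are illustrative):

```python
import torch

B, L, D = 2, 3, 8                      # illustrative: 2 sequences of 3 tokens
x = torch.randn(B, L, D)               # stand-in for prompt_embeds

out_ids = []
for _ in range(B):
    # Text tokens occupy a single (t, h, w) cell; only the sequence axis varies.
    t, h, w = torch.arange(1), torch.arange(1), torch.arange(1)
    seq_l = torch.arange(L)
    coords = torch.cartesian_prod(t, h, w, seq_l)  # shape: (L, 4)
    out_ids.append(coords)

txt_ids = torch.stack(out_ids)         # shape: (B, L, 4)
print(txt_ids[0])                      # rows: [0,0,0,0], [0,0,0,1], [0,0,0,2]
```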
```
class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
    model_name = "flux2"

@@ -551,42 +506,3 @@ class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):

        self.set_block_state(state, block_state)
        return components, state


class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return "Step that prepares the guidance scale tensor for Flux2 inference"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("guidance_scale", default=4.0),
            InputParam("num_images_per_prompt", default=1),
            InputParam(
                "batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        batch_size = block_state.batch_size * block_state.num_images_per_prompt
        guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(batch_size)
        block_state.guidance = guidance

        self.set_block_state(state, block_state)
        return components, state
```
```
@@ -29,16 +29,29 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Flux2UnpackLatentsStep(ModularPipelineBlocks):
class Flux2DecodeStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLFlux2),
            ComponentSpec(
                "image_processor",
                Flux2ImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Step that unpacks the latents from the denoising step"
        return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("output_type", default="pil"),
            InputParam(
                "latents",
                required=True,
@@ -57,9 +70,9 @@ class Flux2UnpackLatentsStep(ModularPipelineBlocks):
    def intermediate_outputs(self) -> List[str]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="The denoise latents from denoising step, unpacked with position IDs.",
                "images",
                type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
                description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
            )
        ]

@@ -94,62 +107,6 @@ class Flux2UnpackLatentsStep(ModularPipelineBlocks):

        return torch.stack(x_list, dim=0)

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        latents = block_state.latents
        latent_ids = block_state.latent_ids

        latents = self._unpack_latents_with_ids(latents, latent_ids)

        block_state.latents = latents

        self.set_block_state(state, block_state)
        return components, state


class Flux2DecodeStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLFlux2),
            ComponentSpec(
                "image_processor",
                Flux2ImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("output_type", default="pil"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The denoised latents from the denoising step",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[str]:
        return [
            OutputParam(
                "images",
                type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
                description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
            )
        ]

    @staticmethod
    def _unpatchify_latents(latents):
        """Convert patchified latents back to regular format."""
@@ -164,20 +121,26 @@ class Flux2DecodeStep(ModularPipelineBlocks):
        block_state = self.get_block_state(state)
        vae = components.vae

        latents = block_state.latents
        if block_state.output_type == "latent":
            block_state.images = block_state.latents
        else:
            latents = block_state.latents
            latent_ids = block_state.latent_ids

        latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
        latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
            latents.device, latents.dtype
        )
        latents = latents * latents_bn_std + latents_bn_mean
            latents = self._unpack_latents_with_ids(latents, latent_ids)

        latents = self._unpatchify_latents(latents)
            latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
            latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
                latents.device, latents.dtype
            )
            latents = latents * latents_bn_std + latents_bn_mean

        block_state.images = vae.decode(latents, return_dict=False)[0]
        block_state.images = components.image_processor.postprocess(
            block_state.images, output_type=block_state.output_type
        )
            latents = self._unpatchify_latents(latents)

            block_state.images = vae.decode(latents, return_dict=False)[0]
            block_state.images = components.image_processor.postprocess(
                block_state.images, output_type=block_state.output_type
            )

        self.set_block_state(state, block_state)
        return components, state
```
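`Flux2DecodeStep` denormalizes latents with the VAE's batch-norm statistics before decoding, computing `x = z * sqrt(running_var + eps) + running_mean` per channel. A sketch of just that step, with stand-in tensors for the `vae.bn` buffers:

```python
import torch

C, eps = 32, 1e-5                       # illustrative: 32 latent channels
latents = torch.randn(1, C, 8, 8)       # normalized latents from the denoiser
running_mean = torch.zeros(C)           # stand-in for vae.bn.running_mean
running_var = torch.ones(C)             # stand-in for vae.bn.running_var

# Broadcast the per-channel statistics over (B, C, H, W).
bn_mean = running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
bn_std = torch.sqrt(running_var.view(1, -1, 1, 1) + eps).to(latents.device, latents.dtype)

denormalized = latents * bn_std + bn_mean  # ready for vae.decode(...)
assert denormalized.shape == latents.shape
```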
```
@@ -16,8 +16,6 @@ from typing import Any, List, Tuple

import torch

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging
@@ -27,8 +25,8 @@ from ..modular_pipeline import (
    ModularPipelineBlocks,
    PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


if is_torch_xla_available():
@@ -136,229 +134,6 @@ class Flux2LoopDenoiser(ModularPipelineBlocks):
        return components, block_state


# same as Flux2LoopDenoiser but guidance=None
class Flux2KleinLoopDenoiser(ModularPipelineBlocks):
    model_name = "flux2-klein"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", Flux2Transformer2DModel)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents for Flux2. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `Flux2DenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("joint_attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The latents to denoise. Shape: (B, seq_len, C)",
            ),
            InputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
            ),
            InputParam(
                "image_latent_ids",
                type_hint=torch.Tensor,
                description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings from Qwen3",
            ),
            InputParam(
                "txt_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for text tokens (T, H, W, L)",
            ),
            InputParam(
                "latent_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for latent tokens (T, H, W, L)",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> PipelineState:
        latents = block_state.latents
        latent_model_input = latents.to(components.transformer.dtype)
        img_ids = block_state.latent_ids

        image_latents = getattr(block_state, "image_latents", None)
        if image_latents is not None:
            latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
            image_latent_ids = block_state.image_latent_ids
            img_ids = torch.cat([img_ids, image_latent_ids], dim=1)

        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred = components.transformer(
            hidden_states=latent_model_input,
            timestep=timestep / 1000,
            guidance=None,
            encoder_hidden_states=block_state.prompt_embeds,
            txt_ids=block_state.txt_ids,
            img_ids=img_ids,
            joint_attention_kwargs=block_state.joint_attention_kwargs,
            return_dict=False,
        )[0]

        noise_pred = noise_pred[:, : latents.size(1)]
        block_state.noise_pred = noise_pred

        return components, block_state


# support CFG for Flux2-Klein base model
class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
    model_name = "flux2-klein"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("transformer", Flux2Transformer2DModel),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 4.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec(name="is_distilled", default=False),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents for Flux2. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `Flux2DenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("joint_attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The latents to denoise. Shape: (B, seq_len, C)",
            ),
            InputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
            ),
            InputParam(
                "image_latent_ids",
                type_hint=torch.Tensor,
                description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings from Qwen3",
            ),
            InputParam(
                "negative_prompt_embeds",
                required=False,
                type_hint=torch.Tensor,
                description="Negative text embeddings from Qwen3",
            ),
            InputParam(
                "txt_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for text tokens (T, H, W, L)",
            ),
            InputParam(
                "negative_txt_ids",
                required=False,
                type_hint=torch.Tensor,
                description="4D position IDs for negative text tokens (T, H, W, L)",
            ),
            InputParam(
                "latent_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for latent tokens (T, H, W, L)",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> PipelineState:
        latents = block_state.latents
        latent_model_input = latents.to(components.transformer.dtype)
        img_ids = block_state.latent_ids

        image_latents = getattr(block_state, "image_latents", None)
        if image_latents is not None:
            latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
            image_latent_ids = block_state.image_latent_ids
            img_ids = torch.cat([img_ids, image_latent_ids], dim=1)

        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        guider_inputs = {
            "encoder_hidden_states": (
                getattr(block_state, "prompt_embeds", None),
                getattr(block_state, "negative_prompt_embeds", None),
            ),
            "txt_ids": (
                getattr(block_state, "txt_ids", None),
                getattr(block_state, "negative_txt_ids", None),
            ),
        }

        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
        guider_state = components.guider.prepare_inputs(guider_inputs)

        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)
            cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}

            noise_pred = components.transformer(
                hidden_states=latent_model_input,
                timestep=timestep / 1000,
                guidance=None,
                img_ids=img_ids,
                joint_attention_kwargs=block_state.joint_attention_kwargs,
                return_dict=False,
                **cond_kwargs,
            )[0]
            guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)]
            components.guider.cleanup_models(components.transformer)

        # perform guidance
        block_state.noise_pred = components.guider(guider_state)[0]

        return components, block_state
```
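The removed `Flux2KleinBaseLoopDenoiser` delegates combining the conditional and unconditional branches to the `ClassifierFreeGuidance` guider. Assuming the standard classifier-free guidance rule, the combination reduces to the following arithmetic (both tensors are stand-ins for the per-branch `noise_pred` outputs):

```python
import torch

guidance_scale = 4.0
noise_pred_cond = torch.randn(1, 16, 64)    # prediction with prompt_embeds
noise_pred_uncond = torch.randn(1, 16, 64)  # prediction with negative/empty embeds

# Classifier-free guidance: push the prediction away from the unconditional branch.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
assert noise_pred.shape == noise_pred_cond.shape
```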
```
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
    model_name = "flux2"

@@ -475,35 +250,3 @@ class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
            " - `Flux2LoopAfterDenoiser`\n"
            "This block supports both text-to-image and image-conditioned generation."
        )


class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper):
    block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents for Flux2. \n"
            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `Flux2KleinLoopDenoiser`\n"
            " - `Flux2LoopAfterDenoiser`\n"
            "This block supports both text-to-image and image-conditioned generation."
        )


class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper):
    block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents for Flux2. \n"
            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `Flux2KleinBaseLoopDenoiser`\n"
            " - `Flux2LoopAfterDenoiser`\n"
            "This block supports both text-to-image and image-conditioned generation."
        )
```
```
@@ -15,15 +15,13 @@
from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
from transformers import AutoProcessor, Mistral3ForConditionalGeneration

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -81,8 +79,10 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
            InputParam("max_sequence_length", type_hint=int, default=512, required=False),
            InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
            InputParam("joint_attention_kwargs"),
        ]

    @property
@@ -99,7 +99,14 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
    @staticmethod
    def check_inputs(block_state):
        prompt = block_state.prompt
        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        prompt_embeds = getattr(block_state, "prompt_embeds", None)

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
                "Please make sure to only forward one of the two."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @staticmethod
@@ -158,6 +165,10 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):

        block_state.device = components._execution_device

        if block_state.prompt_embeds is not None:
            self.set_block_state(state, block_state)
            return components, state

        prompt = block_state.prompt
        if prompt is None:
            prompt = ""
@@ -194,6 +205,7 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
        ]

    @property
@@ -210,8 +222,15 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
    @staticmethod
    def check_inputs(block_state):
        prompt = block_state.prompt
        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
        prompt_embeds = getattr(block_state, "prompt_embeds", None)

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
                "Please make sure to only forward one of the two."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
@@ -225,6 +244,10 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):

        block_state.device = components._execution_device

        if block_state.prompt_embeds is not None:
            self.set_block_state(state, block_state)
            return components, state

        prompt = block_state.prompt
        if prompt is None:
            prompt = ""
@@ -247,289 +270,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
        return components, state


class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
    model_name = "flux2-klein"

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", Qwen3ForCausalLM),
            ComponentSpec("tokenizer", Qwen2TokenizerFast),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec(name="is_distilled", default=True),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("max_sequence_length", type_hint=int, default=512, required=False),
            InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Text embeddings from qwen3 used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        prompt = block_state.prompt

        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @staticmethod
    # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
    def _get_qwen3_prompt_embeds(
        text_encoder: Qwen3ForCausalLM,
        tokenizer: Qwen2TokenizerFast,
        prompt: Union[str, List[str]],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
        hidden_states_layers: List[int] = (9, 18, 27),
    ):
        dtype = text_encoder.dtype if dtype is None else dtype
        device = text_encoder.device if device is None else device

        prompt = [prompt] if isinstance(prompt, str) else prompt

        all_input_ids = []
        all_attention_masks = []

        for single_prompt in prompt:
            messages = [{"role": "user", "content": single_prompt}]
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
            inputs = tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_sequence_length,
            )

            all_input_ids.append(inputs["input_ids"])
            all_attention_masks.append(inputs["attention_mask"])

        input_ids = torch.cat(all_input_ids, dim=0).to(device)
        attention_mask = torch.cat(all_attention_masks, dim=0).to(device)

        # Forward pass through the model
        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        # Only use outputs from intermediate layers and stack them
        out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
        out = out.to(dtype=dtype, device=device)

        batch_size, num_channels, seq_len, hidden_dim = out.shape
        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)

        return prompt_embeds

    @torch.no_grad()
    def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        device = components._execution_device

        prompt = block_state.prompt
        if prompt is None:
            prompt = ""
        prompt = [prompt] if isinstance(prompt, str) else prompt

        block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
            text_encoder=components.text_encoder,
            tokenizer=components.tokenizer,
            prompt=prompt,
            device=device,
            max_sequence_length=block_state.max_sequence_length,
            hidden_states_layers=block_state.text_encoder_out_layers,
        )

        self.set_block_state(state, block_state)
        return components, state
```
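`_get_qwen3_prompt_embeds` above concatenates several intermediate Qwen3 hidden states along the feature axis: stack to (B, num_layers, L, D), then permute and reshape to (B, L, num_layers * D). A self-contained sketch with dummy tensors in place of `output.hidden_states`:

```python
import torch

B, L, D = 1, 5, 16                       # illustrative batch, seq len, hidden size
hidden_states = [torch.randn(B, L, D) for _ in range(32)]  # dummy per-layer outputs
hidden_states_layers = (9, 18, 27)       # the layers the step reads by default

out = torch.stack([hidden_states[k] for k in hidden_states_layers], dim=1)  # (B, 3, L, D)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(B, L, len(hidden_states_layers) * D)
assert prompt_embeds.shape == (B, L, 3 * D)  # features from all three layers per token
```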
```
class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
    model_name = "flux2-klein"

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", Qwen3ForCausalLM),
            ComponentSpec("tokenizer", Qwen2TokenizerFast),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 4.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec(name="is_distilled", default=False),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("max_sequence_length", type_hint=int, default=512, required=False),
            InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Text embeddings from qwen3 used to guide the image generation",
            ),
            OutputParam(
                "negative_prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Negative text embeddings from qwen3 used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        prompt = block_state.prompt

        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @staticmethod
    # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
    def _get_qwen3_prompt_embeds(
        text_encoder: Qwen3ForCausalLM,
        tokenizer: Qwen2TokenizerFast,
        prompt: Union[str, List[str]],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
        hidden_states_layers: List[int] = (9, 18, 27),
    ):
        dtype = text_encoder.dtype if dtype is None else dtype
        device = text_encoder.device if device is None else device

        prompt = [prompt] if isinstance(prompt, str) else prompt

        all_input_ids = []
        all_attention_masks = []

        for single_prompt in prompt:
            messages = [{"role": "user", "content": single_prompt}]
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
            inputs = tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_sequence_length,
            )

            all_input_ids.append(inputs["input_ids"])
            all_attention_masks.append(inputs["attention_mask"])

        input_ids = torch.cat(all_input_ids, dim=0).to(device)
        attention_mask = torch.cat(all_attention_masks, dim=0).to(device)

        # Forward pass through the model
        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        # Only use outputs from intermediate layers and stack them
        out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
        out = out.to(dtype=dtype, device=device)

        batch_size, num_channels, seq_len, hidden_dim = out.shape
        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)

        return prompt_embeds

    @torch.no_grad()
    def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        device = components._execution_device

        prompt = block_state.prompt
        if prompt is None:
            prompt = ""
        prompt = [prompt] if isinstance(prompt, str) else prompt

        block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
            text_encoder=components.text_encoder,
            tokenizer=components.tokenizer,
            prompt=prompt,
            device=device,
            max_sequence_length=block_state.max_sequence_length,
            hidden_states_layers=block_state.text_encoder_out_layers,
        )

        if components.requires_unconditional_embeds:
            negative_prompt = [""] * len(prompt)
            block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds(
                text_encoder=components.text_encoder,
                tokenizer=components.tokenizer,
                prompt=negative_prompt,
                device=device,
                max_sequence_length=block_state.max_sequence_length,
                hidden_states_layers=block_state.text_encoder_out_layers,
            )
        else:
            block_state.negative_prompt_embeds = None

        self.set_block_state(state, block_state)
        return components, state


class Flux2VaeEncoderStep(ModularPipelineBlocks):
    model_name = "flux2"
```
@@ -47,7 +47,7 @@ class Flux2TextInputStep(ModularPipelineBlocks):
|
||||
required=True,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
|
||||
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -89,90 +89,6 @@ class Flux2TextInputStep(ModularPipelineBlocks):
        return components, state


class Flux2KleinBaseTextInputStep(ModularPipelineBlocks):
    model_name = "flux2-klein"

    @property
    def description(self) -> str:
        return (
            "This step:\n"
            " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
            " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_images_per_prompt", default=1),
            InputParam(
                "prompt_embeds",
                required=True,
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Pre-generated text embeddings. Can be generated from text_encoder step.",
            ),
            InputParam(
                "negative_prompt_embeds",
                required=False,
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[str]:
        return [
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
            ),
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Text embeddings used to guide the image generation",
            ),
            OutputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Negative text embeddings used to guide the image generation",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.batch_size = block_state.prompt_embeds.shape[0]
        block_state.dtype = block_state.prompt_embeds.dtype

        _, seq_len, _ = block_state.prompt_embeds.shape
        block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
        block_state.prompt_embeds = block_state.prompt_embeds.view(
            block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
        )

        if block_state.negative_prompt_embeds is not None:
            _, seq_len, _ = block_state.negative_prompt_embeds.shape
            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
                1, block_state.num_images_per_prompt, 1
            )
            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
                block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
            )

        self.set_block_state(state, block_state)
        return components, state


class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
    model_name = "flux2"

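Note on the `repeat`/`view` pattern in the removed `__call__` above: it duplicates each prompt embedding `num_images_per_prompt` times and folds the copies into the batch dimension. A minimal self-contained sketch of that expansion (shapes illustrative, not tied to diffusers):

import torch

# (B, L, D) -> repeat along the sequence axis -> (B, L * N, D)
# -> view folds the N copies back into the batch axis -> (B * N, L, D)
batch_size, seq_len, dim, num_images_per_prompt = 2, 4, 8, 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)
expanded = prompt_embeds.repeat(1, num_images_per_prompt, 1)
expanded = expanded.view(batch_size * num_images_per_prompt, seq_len, -1)
assert expanded.shape == (batch_size * num_images_per_prompt, seq_len, dim)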
@@ -12,22 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    Flux2PrepareGuidanceStep,
    Flux2PrepareImageLatentsStep,
    Flux2PrepareLatentsStep,
    Flux2RoPEInputsStep,
    Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .decoders import Flux2DecodeStep
from .denoise import Flux2DenoiseStep
from .encoders import (
    Flux2RemoteTextEncoderStep,
@@ -47,6 +41,7 @@ Flux2VaeEncoderBlocks = InsertableDict(
    [
        ("preprocess", Flux2ProcessImagesInputStep()),
        ("encode", Flux2VaeEncoderStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
    ]
)

@@ -77,56 +72,33 @@ class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
    )


Flux2CoreDenoiseBlocks = InsertableDict(
Flux2BeforeDenoiseBlocks = InsertableDict(
    [
        ("input", Flux2TextInputStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_guidance", Flux2PrepareGuidanceStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("after_denoise", Flux2UnpackLatentsStep()),
    ]
)


class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = Flux2CoreDenoiseBlocks.values()
    block_names = Flux2CoreDenoiseBlocks.keys()
    block_classes = Flux2BeforeDenoiseBlocks.values()
    block_names = Flux2BeforeDenoiseBlocks.keys()

    @property
    def description(self):
        return (
            "Core denoise step that performs the denoising process for Flux2-dev.\n"
            " - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
            " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
            " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
            " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
            " - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n"
            " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
            " - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n"
            " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
        )

    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents",
                type_hint=torch.Tensor,
                description="The latents from the denoising step.",
            )
        ]
        return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2TextEncoderStep()),
        ("vae_encoder", Flux2AutoVaeEncoderStep()),
        ("denoise", Flux2CoreDenoiseStep()),
        ("text_input", Flux2TextInputStep()),
        ("vae_image_encoder", Flux2AutoVaeEncoderStep()),
        ("before_denoise", Flux2BeforeDenoiseStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)
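For orientation, `AUTO_BLOCKS` above is what `Flux2AutoBlocks` assembles into a runnable pipeline. A hedged usage sketch (the repo id is illustrative, and `init_pipeline` is assumed from the modular-pipelines API; the exact entry point may differ):

from diffusers.modular_pipelines import Flux2AutoBlocks

# Build the sequential block graph; printing it lists the sub-blocks in execution
# order: text_encoder -> vae_image_encoder -> before_denoise -> denoise -> decode.
blocks = Flux2AutoBlocks()
print(blocks)
pipe = blocks.init_pipeline("black-forest-labs/FLUX.2-dev")  # illustrative repo id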
@@ -135,8 +107,10 @@ AUTO_BLOCKS = InsertableDict(
REMOTE_AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2RemoteTextEncoderStep()),
        ("vae_encoder", Flux2AutoVaeEncoderStep()),
        ("denoise", Flux2CoreDenoiseStep()),
        ("text_input", Flux2TextInputStep()),
        ("vae_image_encoder", Flux2AutoVaeEncoderStep()),
        ("before_denoise", Flux2BeforeDenoiseStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)
@@ -156,16 +130,6 @@ class Flux2AutoBlocks(SequentialPipelineBlocks):
            "- For image-conditioned generation, you need to provide `image` (list of PIL images)."
        )

    @property
    def outputs(self):
        return [
            OutputParam(
                name="images",
                type_hint=List[PIL.Image.Image],
                description="The images from the decoding step.",
            )
        ]


TEXT2IMAGE_BLOCKS = InsertableDict(
    [
@@ -173,10 +137,8 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
        ("text_input", Flux2TextInputStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_guidance", Flux2PrepareGuidanceStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("after_denoise", Flux2UnpackLatentsStep()),
        ("decode", Flux2DecodeStep()),
    ]
)
@@ -190,10 +152,8 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_guidance", Flux2PrepareGuidanceStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("after_denoise", Flux2UnpackLatentsStep()),
        ("decode", Flux2DecodeStep()),
    ]
)
@@ -1,232 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import PIL.Image
import torch

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
    Flux2KleinBaseRoPEInputsStep,
    Flux2PrepareImageLatentsStep,
    Flux2PrepareLatentsStep,
    Flux2RoPEInputsStep,
    Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
from .encoders import (
    Flux2KleinBaseTextEncoderStep,
    Flux2KleinTextEncoderStep,
    Flux2VaeEncoderStep,
)
from .inputs import (
    Flux2KleinBaseTextInputStep,
    Flux2ProcessImagesInputStep,
    Flux2TextInputStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

################
# VAE encoder
################

Flux2KleinVaeEncoderBlocks = InsertableDict(
    [
        ("preprocess", Flux2ProcessImagesInputStep()),
        ("encode", Flux2VaeEncoderStep()),
    ]
)


class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = Flux2KleinVaeEncoderBlocks.values()
    block_names = Flux2KleinVaeEncoderBlocks.keys()

    @property
    def description(self) -> str:
        return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."


class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
    block_classes = [Flux2KleinVaeEncoderSequentialStep]
    block_names = ["img_conditioning"]
    block_trigger_inputs = ["image"]

    @property
    def description(self):
        return (
            "VAE encoder step that encodes the image inputs into their latent representations.\n"
            "This is an auto pipeline block that works for image conditioning tasks.\n"
            " - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n"
            " - If `image` is not provided, step will be skipped."
        )


###
### Core denoise
###

Flux2KleinCoreDenoiseBlocks = InsertableDict(
    [
        ("input", Flux2TextInputStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2KleinDenoiseStep()),
        ("after_denoise", Flux2UnpackLatentsStep()),
    ]
)

class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
    model_name = "flux2-klein"

    block_classes = Flux2KleinCoreDenoiseBlocks.values()
    block_names = Flux2KleinCoreDenoiseBlocks.keys()

    @property
    def description(self):
        return (
            "Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
            " - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
            " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
            " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
            " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
            " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
            " - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
            " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
        )

    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents",
                type_hint=torch.Tensor,
                description="The latents from the denoising step.",
            )
        ]


Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
    [
        ("input", Flux2KleinBaseTextInputStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
        ("denoise", Flux2KleinBaseDenoiseStep()),
        ("after_denoise", Flux2UnpackLatentsStep()),
    ]
)


class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
    model_name = "flux2-klein"
    block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
    block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()

    @property
    def description(self):
        return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
        return (
            "Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
            " - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n"
            " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
            " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
            " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
            " - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n"
            " - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
            " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
        )

    @property
    def outputs(self):
        return [
            OutputParam(
                name="latents",
                type_hint=torch.Tensor,
                description="The latents from the denoising step.",
            )
        ]


###
### Auto blocks
###
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
    model_name = "flux2-klein"
    block_classes = [
        Flux2KleinTextEncoderStep(),
        Flux2KleinAutoVaeEncoderStep(),
        Flux2KleinCoreDenoiseStep(),
        Flux2DecodeStep(),
    ]
    block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]

    @property
    def description(self):
        return (
            "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n"
            + " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
            + " - for text-to-image generation, all you need to provide is `prompt`.\n"
        )

    @property
    def outputs(self):
        return [
            OutputParam(
                name="images",
                type_hint=List[PIL.Image.Image],
                description="The images from the decoding step.",
            )
        ]


class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
    model_name = "flux2-klein"
    block_classes = [
        Flux2KleinBaseTextEncoderStep(),
        Flux2KleinAutoVaeEncoderStep(),
        Flux2KleinBaseCoreDenoiseStep(),
        Flux2DecodeStep(),
    ]
    block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]

    @property
    def description(self):
        return (
            "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
            + " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
            + " - for text-to-image generation, all you need to provide is `prompt`.\n"
        )

    @property
    def outputs(self):
        return [
            OutputParam(
                name="images",
                type_hint=List[PIL.Image.Image],
                description="The images from the decoding step.",
            )
        ]
@@ -13,8 +13,6 @@
# limitations under the License.


from typing import Any, Dict, Optional

from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -57,56 +55,3 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
        if getattr(self, "transformer", None):
            num_channels_latents = self.transformer.config.in_channels // 4
        return num_channels_latents


class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
    """
    A ModularPipeline for Flux2-Klein.

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "Flux2KleinBaseAutoBlocks"

    def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
        if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
            return "Flux2KleinAutoBlocks"
        else:
            return "Flux2KleinBaseAutoBlocks"

    @property
    def default_height(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_width(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_sample_size(self):
        return 128

    @property
    def vae_scale_factor(self):
        vae_scale_factor = 8
        if getattr(self, "vae", None) is not None:
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return vae_scale_factor

    @property
    def num_channels_latents(self):
        num_channels_latents = 32
        if getattr(self, "transformer", None):
            num_channels_latents = self.transformer.config.in_channels // 4
        return num_channels_latents

    @property
    def requires_unconditional_embeds(self):
        if hasattr(self.config, "is_distilled") and self.config.is_distilled:
            return False

        requires_unconditional_embeds = False
        if hasattr(self, "guider") and self.guider is not None:
            requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1

        return requires_unconditional_embeds

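The deleted `get_default_blocks_name`/`requires_unconditional_embeds` pair encodes one decision: a distilled Klein checkpoint (config flag `is_distilled`) skips classifier-free guidance and negative embeddings. A minimal standalone sketch of that dispatch (plain Python, not the diffusers API):

def pick_blocks(config_dict):
    # Distilled checkpoints use the guidance-free blocks; everything else gets
    # the base blocks, which carry negative embeddings for CFG.
    if config_dict and config_dict.get("is_distilled"):
        return "Flux2KleinAutoBlocks"
    return "Flux2KleinBaseAutoBlocks"

assert pick_blocks({"is_distilled": True}) == "Flux2KleinAutoBlocks"
assert pick_blocks({"is_distilled": False}) == "Flux2KleinBaseAutoBlocks"
assert pick_blocks(None) == "Flux2KleinBaseAutoBlocks"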
@@ -59,7 +59,6 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
        ("flux", "FluxModularPipeline"),
        ("flux-kontext", "FluxKontextModularPipeline"),
        ("flux2", "Flux2ModularPipeline"),
        ("flux2-klein", "Flux2KleinModularPipeline"),
        ("qwenimage", "QwenImageModularPipeline"),
        ("qwenimage-edit", "QwenImageEditModularPipeline"),
        ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),

@@ -756,6 +756,7 @@ def load_sub_model(
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
    use_safetensors: bool,
    use_flashpack: bool,
    dduf_entries: Optional[Dict[str, DDUFEntry]],
    provider_options: Any,
    disable_mmap: bool,
@@ -838,6 +839,9 @@
    loading_kwargs["variant"] = model_variants.pop(name, None)
    loading_kwargs["use_safetensors"] = use_safetensors

    if is_diffusers_model:
        loading_kwargs["use_flashpack"] = use_flashpack

    if from_flax:
        loading_kwargs["from_flax"] = True

@@ -243,6 +243,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        variant: Optional[str] = None,
        max_shard_size: Optional[Union[int, str]] = None,
        push_to_hub: bool = False,
        use_flashpack: bool = False,
        **kwargs,
    ):
        """
@@ -268,7 +269,9 @@
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            use_flashpack (`bool`, *optional*, defaults to `False`):
                Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip
                install flashpack`.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
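A hedged usage sketch for the new save-side flag (pipeline id illustrative; any pipeline loadable via `from_pretrained` works the same way):

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some-org/some-pipeline")  # illustrative id
# Components whose save method accepts `use_flashpack` are written as FlashPack
# binaries; all other components keep their usual serialization.
pipe.save_pretrained("./pipeline-flashpack", use_flashpack=True)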
@@ -340,6 +343,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
            save_method_accept_variant = "variant" in save_method_signature.parameters
            save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
            save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters

            save_kwargs = {}
            if save_method_accept_safe:
@@ -349,6 +353,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            if save_method_accept_max_shard_size and max_shard_size is not None:
                # max_shard_size is expected to not be None in ModelMixin
                save_kwargs["max_shard_size"] = max_shard_size
            if save_method_accept_flashpack:
                save_kwargs["use_flashpack"] = use_flashpack

            save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

@@ -707,6 +713,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
            use_flashpack (`bool`, *optional*, defaults to `False`):
                If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
                is found. If flashpack is unavailable or the `.flashpack` file cannot be used, loading automatically
                falls back to the standard path (for example, `safetensors`). Requires the `flashpack` library: `pip
                install flashpack`.
            use_onnx (`bool`, *optional*, defaults to `None`):
                If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
                will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
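And the matching load side, per the fallback semantics documented above (directory name carried over from the save sketch):

from diffusers import DiffusionPipeline

# Uses .flashpack weights where present; falls back to the standard
# safetensors path for components without a usable FlashPack file.
pipe = DiffusionPipeline.from_pretrained("./pipeline-flashpack", use_flashpack=True)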
@@ -772,6 +783,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        variant = kwargs.pop("variant", None)
        dduf_file = kwargs.pop("dduf_file", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        use_flashpack = kwargs.pop("use_flashpack", False)
        use_onnx = kwargs.pop("use_onnx", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
        quantization_config = kwargs.pop("quantization_config", None)
@@ -1061,6 +1073,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                low_cpu_mem_usage=low_cpu_mem_usage,
                cached_folder=cached_folder,
                use_safetensors=use_safetensors,
                use_flashpack=use_flashpack,
                dduf_entries=dduf_entries,
                provider_options=provider_options,
                disable_mmap=disable_mmap,

@@ -226,6 +226,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
        time_shift_type: Literal["exponential"] = "exponential",
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        shift_terminal: Optional[float] = None,
    ) -> None:
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -245,6 +246,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
        if shift_terminal is not None and not use_flow_sigmas:
            raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)
@@ -313,8 +316,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
        self._begin_index = begin_index

    def set_timesteps(
        self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
    ) -> None:
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

@@ -323,13 +330,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
            sigmas (`List[float]`, *optional*):
                Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
                automatically.
            mu (`float`, *optional*):
                Optional mu parameter for dynamic shifting when using the exponential time shift type.
        """
        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")

        if sigmas is not None:
            if not self.config.use_flow_sigmas:
                raise ValueError(
                    "Passing `sigmas` is only supported when `use_flow_sigmas=True`. "
                    "Please set `use_flow_sigmas=True` during scheduler initialization."
                )
            num_inference_steps = len(sigmas)

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if mu is not None:
            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
            self.config.flow_shift = np.exp(mu)
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
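A sketch of how the new `sigmas` argument is meant to be driven, per the signature and validation above (values illustrative; the array-conversion details live outside this hunk):

from diffusers import UniPCMultistepScheduler

# Custom sigmas require `use_flow_sigmas=True`; `num_inference_steps` is then
# inferred from the list length (4 here).
scheduler = UniPCMultistepScheduler(use_flow_sigmas=True, prediction_type="flow_prediction")
scheduler.set_timesteps(sigmas=[0.99, 0.75, 0.5, 0.25], device="cpu")
print(scheduler.timesteps)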
@@ -354,8 +372,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
            if sigmas is None:
                sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -375,6 +394,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        elif self.config.use_exponential_sigmas:
            if sigmas is None:
                sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -389,6 +410,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        elif self.config.use_beta_sigmas:
            if sigmas is None:
                sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -403,9 +426,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        elif self.config.use_flow_sigmas:
            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
            sigmas = 1.0 - alphas
            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
            if sigmas is None:
                sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
            if self.config.use_dynamic_shifting:
                sigmas = self.time_shift(mu, 1.0, sigmas)
            else:
                sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
            if self.config.shift_terminal:
                sigmas = self.stretch_shift_to_terminal(sigmas)
            eps = 1e-6
            if np.fabs(sigmas[0] - 1) < eps:
                # to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
                sigmas[0] -= eps
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = sigmas[-1]
@@ -417,6 +449,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                )
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        else:
            if sigmas is None:
                sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            if self.config.final_sigmas_type == "sigma_min":
                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -446,6 +480,43 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        if self.config.time_shift_type == "exponential":
            return self._time_shift_exponential(mu, sigma, t)
        elif self.config.time_shift_type == "linear":
            return self._time_shift_linear(mu, sigma, t)

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal
    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
        value.

        Reference:
        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51

        Args:
            t (`torch.Tensor`):
                A tensor of timesteps to be stretched and shifted.

        Returns:
            `torch.Tensor`:
                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
        """
        one_minus_z = 1 - t
        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
        stretched_t = 1 - (one_minus_z / scale_factor)
        return stretched_t

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
    def _time_shift_exponential(self, mu, sigma, t):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
    def _time_shift_linear(self, mu, sigma, t):
        return mu / (mu + (1 / t - 1) ** sigma)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
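To make the exponential shift above concrete, a worked check of t' = exp(mu) / (exp(mu) + (1/t - 1)**sigma): with mu = ln(3) (equivalent to `flow_shift = 3`, since `set_timesteps` sets `flow_shift = np.exp(mu)`) and sigma = 1, the midpoint t = 0.5 maps to 0.75, i.e. the schedule is pushed toward the high-noise end:

import math

mu, sigma, t = math.log(3.0), 1.0, 0.5
# Exponential time shift, as in _time_shift_exponential above.
t_shifted = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
assert abs(t_shifted - 0.75) < 1e-12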
@@ -17,51 +17,6 @@ class Flux2AutoBlocks(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class Flux2KleinAutoBlocks(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class Flux2KleinModularPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class Flux2ModularPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

src/diffusers/utils/flashpack_utils.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import json
import os
from typing import Optional

from ..utils import _add_variant
from .import_utils import is_flashpack_available
from .logging import get_logger


logger = get_logger(__name__)


def save_flashpack(
    model,
    save_directory: str,
    variant: Optional[str] = None,
    is_main_process: bool = True,
):
    """
    Save model weights in FlashPack format along with a metadata config.

    Args:
        model: Diffusers model instance
        save_directory (`str`): Directory to save weights
        variant (`str`, *optional*): Model variant
    """
    if not is_flashpack_available():
        raise ImportError(
            "The `use_flashpack=True` argument requires the `flashpack` package. "
            "Install it with `pip install flashpack`."
        )

    from flashpack import pack_to_file

    os.makedirs(save_directory, exist_ok=True)

    weights_name = _add_variant("model.flashpack", variant)
    weights_path = os.path.join(save_directory, weights_name)
    config_path = os.path.join(save_directory, "flashpack_config.json")

    try:
        target_dtype = getattr(model, "dtype", None)
        logger.warning(f"Dtype used for FlashPack save: {target_dtype}")

        # 1. Save binary weights
        pack_to_file(model, weights_path, target_dtype=target_dtype)

        # 2. Save config metadata (best-effort)
        if hasattr(model, "config"):
            try:
                if hasattr(model.config, "to_dict"):
                    config_data = model.config.to_dict()
                else:
                    config_data = dict(model.config)

                with open(config_path, "w") as f:
                    json.dump(config_data, f, indent=4)

            except Exception as config_err:
                logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}")

    except Exception as e:
        logger.error(f"Failed to save weights in FlashPack format: {e}")
        raise


def load_flashpack(model, flashpack_file: str):
    """
    Assign FlashPack weights from a file into an initialized PyTorch model.
    """
    if not is_flashpack_available():
        raise ImportError("FlashPack weights require the `flashpack` package. Install with `pip install flashpack`.")

    from flashpack import assign_from_file

    logger.warning(f"Loading FlashPack weights from {flashpack_file}")

    try:
        assign_from_file(model, flashpack_file)
    except Exception as e:
        raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e
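A hedged round-trip sketch for the two helpers above (the model class is illustrative; any initialized diffusers model with a `dtype` and `config` should behave the same):

from diffusers import UNet2DModel
from diffusers.utils.flashpack_utils import load_flashpack, save_flashpack

model = UNet2DModel()  # illustrative; randomly initialized weights
save_flashpack(model, "./unet-flashpack")  # writes model.flashpack + flashpack_config.json

# Loading assigns weights in place into an already-constructed module.
fresh = UNet2DModel()
load_flashpack(fresh, "./unet-flashpack/model.flashpack")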
@@ -231,6 +231,7 @@ _aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_av_available, _av_version = _is_package_available("av")
_flashpack_available, _flashpack_version = _is_package_available("flashpack")


def is_torch_available():
@@ -425,6 +426,10 @@ def is_av_available():
    return _av_available


def is_flashpack_available():
    return _flashpack_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -942,6 +947,16 @@ def is_aiter_version(operation: str, version: str):
    return compare_versions(parse(_aiter_version), operation, version)


@cache
def is_flashpack_version(operation: str, version: str):
    """
    Compares the current flashpack version to a given reference with an operation.
    """
    if not _flashpack_available:
        return False
    return compare_versions(parse(_flashpack_version), operation, version)


def get_objects_from_module(module):
    """
    Returns a dict of object names and values in a module, while skipping private/internal objects

@@ -1,91 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import PIL
import pytest

from diffusers.modular_pipelines import (
    Flux2KleinAutoBlocks,
    Flux2KleinModularPipeline,
)

from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin


class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = Flux2KleinModularPipeline
    pipeline_blocks_class = Flux2KleinAutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"

    params = frozenset(["prompt", "height", "width"])
    batch_params = frozenset(["prompt"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
            "max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
            "text_encoder_out_layers": (1,),
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        return inputs

    def test_float16_inference(self):
        super().test_float16_inference(9e-2)


class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = Flux2KleinModularPipeline
    pipeline_blocks_class = Flux2KleinAutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"

    params = frozenset(["prompt", "height", "width", "image"])
    batch_params = frozenset(["prompt", "image"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
            "max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
            "text_encoder_out_layers": (1,),
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
        inputs["image"] = init_image

        return inputs

    def test_float16_inference(self):
        super().test_float16_inference(9e-2)

    @pytest.mark.skip(reason="batched inference is currently not supported")
    def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
        return
@@ -1,91 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import PIL
import pytest

from diffusers.modular_pipelines import (
    Flux2KleinBaseAutoBlocks,
    Flux2KleinModularPipeline,
)

from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin


class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = Flux2KleinModularPipeline
    pipeline_blocks_class = Flux2KleinBaseAutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"

    params = frozenset(["prompt", "height", "width"])
    batch_params = frozenset(["prompt"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
            "max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
            "text_encoder_out_layers": (1,),
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        return inputs

    def test_float16_inference(self):
        super().test_float16_inference(9e-2)


class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = Flux2KleinModularPipeline
    pipeline_blocks_class = Flux2KleinBaseAutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"

    params = frozenset(["prompt", "height", "width", "image"])
    batch_params = frozenset(["prompt", "image"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
            "max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
            "text_encoder_out_layers": (1,),
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
        inputs["image"] = init_image

        return inputs

    def test_float16_inference(self):
        super().test_float16_inference(9e-2)

    @pytest.mark.skip(reason="batched inference is currently not supported")
    def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
        return