Compare commits

..

9 Commits

Author SHA1 Message Date
sayakpaul
ff26d9ffd5 up 2026-01-22 17:17:00 +05:30
sayakpaul
668f265054 up 2026-01-22 17:17:00 +05:30
sayakpaul
55eaa6efb2 style 2026-01-22 17:17:00 +05:30
Sayak Paul
b603429ff5 Merge branch 'main' into fal-flashpack 2026-01-22 17:14:14 +05:30
Aryan V S
7a02fadad3 [scheduler] Support custom sigmas in UniPCMultistepScheduler (#12109)
* update

* fix tests

* Apply suggestions from code review

* Revert default flow sigmas change so that tests relying on UniPC multistep still pass

* Remove custom timesteps for UniPC multistep set_timesteps

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Daniel Gu <dgu8957@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-01-21 17:18:59 -08:00
“devanshi00”
3bc3fdb035 redundant model initialisation removed final 2026-01-21 12:31:43 +05:30
“devanshi00”
8cc38a75d3 redundant model initialisation removed 2026-01-21 12:27:42 +05:30
“devanshi00”
e5bb10cfe1 review comments resolved 2026-01-21 04:22:50 +05:30
“devanshi00”
ec541906c5 added fal-flashpack support 2026-01-19 14:52:15 +05:30
11 changed files with 320 additions and 245 deletions

View File

@@ -675,6 +675,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "10GB",
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -707,6 +708,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a
binary format that allows for faster loading. Requires the `flashpack` library to be installed.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -727,12 +731,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
@@ -746,67 +744,80 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Only save the model itself if we are using distributed training
model_to_save = self
# Attach architecture to the config
# Save the config
if is_main_process:
model_to_save.save_config(save_directory)
# Save the model
state_dict = model_to_save.state_dict()
if use_flashpack:
if not is_main_process:
return
# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
from ..utils.flashpack_utils import save_flashpack
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
save_flashpack(model_to_save, save_directory, variant=variant)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict = model_to_save.state_dict()
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
# Save each shard
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
# Save index file if sharded
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
# Push to hub if requested (common to both paths)
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token)
@@ -939,6 +950,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack)
weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack`
file cannot be used, automatic fallback to the standard loading path (for example, `safetensors`).
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -982,6 +997,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1199,7 +1215,31 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
else:
flashpack_file = None
if use_flashpack:
try:
flashpack_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant("model.flashpack", variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
except EnvironmentError:
flashpack_file = None
logger.warning(
"`use_flashpack` was specified to be True but not flashpack file was found. Resorting to non-flashpack alternatives."
)
if flashpack_file is None:
# in the case it is sharded, we have already the index
if is_sharded:
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
@@ -1215,6 +1255,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dduf_entries=dduf_entries,
)
elif use_safetensors:
logger.warning("Trying to load model weights with safetensors format.")
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
@@ -1280,6 +1321,29 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
if flashpack_file is not None:
from ..utils.flashpack_utils import load_flashpack
# Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
# the model with meta tensors. Since FlashPack cannot write into meta tensors, we
# explicitly materialize parameters before loading to ensure correctness and parity
# with the standard loading path.
if any(p.device.type == "meta" for p in model.parameters()):
model.to_empty(device="cpu")
load_flashpack(model, flashpack_file)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
model.eval()
if output_loading_info:
return model, {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}
return model
state_dict = None
if not is_sharded:
# Time to load the checkpoint
@@ -1327,7 +1391,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
is_parallel_loading_enabled=is_parallel_loading_enabled,
disable_mmap=disable_mmap,
)
loading_info = {
"missing_keys": missing_keys,
@@ -1373,6 +1436,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if output_loading_info:
return model, loading_info
logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully")
return model
# Adapted from `transformers`.

View File

@@ -756,6 +756,7 @@ def load_sub_model(
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
use_flashpack: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
disable_mmap: bool,
@@ -838,6 +839,9 @@ def load_sub_model(
loading_kwargs["variant"] = model_variants.pop(name, None)
loading_kwargs["use_safetensors"] = use_safetensors
if is_diffusers_model:
loading_kwargs["use_flashpack"] = use_flashpack
if from_flax:
loading_kwargs["from_flax"] = True

View File

@@ -243,6 +243,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant: Optional[str] = None,
max_shard_size: Optional[Union[int, str]] = None,
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -268,7 +269,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip
install flashpack`.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -340,6 +343,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters
save_kwargs = {}
if save_method_accept_safe:
@@ -349,6 +353,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if save_method_accept_max_shard_size and max_shard_size is not None:
# max_shard_size is expected to not be None in ModelMixin
save_kwargs["max_shard_size"] = max_shard_size
if save_method_accept_flashpack:
save_kwargs["use_flashpack"] = use_flashpack
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
@@ -707,6 +713,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to
the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
flashpack`.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
@@ -772,6 +783,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant = kwargs.pop("variant", None)
dduf_file = kwargs.pop("dduf_file", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
@@ -1061,6 +1073,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
use_safetensors=use_safetensors,
use_flashpack=use_flashpack,
dduf_entries=dduf_entries,
provider_options=provider_options,
disable_mmap=disable_mmap,

View File

@@ -226,6 +226,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
time_shift_type: Literal["exponential"] = "exponential",
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
shift_terminal: Optional[float] = None,
) -> None:
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -245,6 +246,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
if shift_terminal is not None and not use_flow_sigmas:
raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
@@ -313,8 +316,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index
def set_timesteps(
self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
) -> None:
self,
num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -323,13 +330,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
sigmas (`List[float]`, *optional*):
Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
automatically.
mu (`float`, *optional*):
Optional mu parameter for dynamic shifting when using exponential time shift type.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
if sigmas is not None:
if not self.config.use_flow_sigmas:
raise ValueError(
"Passing `sigmas` is only supported when `use_flow_sigmas=True`. "
"Please set `use_flow_sigmas=True` during scheduler initialization."
)
num_inference_steps = len(sigmas)
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if mu is not None:
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
self.config.flow_shift = np.exp(mu)
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
@@ -354,8 +372,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -375,6 +394,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -389,6 +410,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_beta_sigmas:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -403,9 +426,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_flow_sigmas:
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
if sigmas is None:
sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
eps = 1e-6
if np.fabs(sigmas[0] - 1) < eps:
# to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
sigmas[0] -= eps
timesteps = (sigmas * self.config.num_train_timesteps).copy()
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
@@ -417,6 +449,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
else:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -446,6 +480,43 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":
return self._time_shift_linear(mu, sigma, t)
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
r"""
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
value.
Reference:
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
Args:
t (`torch.Tensor`):
A tensor of timesteps to be stretched and shifted.
Returns:
`torch.Tensor`:
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
"""
one_minus_z = 1 - t
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
stretched_t = 1 - (one_minus_z / scale_factor)
return stretched_t
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
def _time_shift_exponential(self, mu, sigma, t):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
def _time_shift_linear(self, mu, sigma, t):
return mu / (mu + (1 / t - 1) ** sigma)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""

View File

@@ -0,0 +1,81 @@
import json
import os
from typing import Optional
from ..utils import _add_variant
from .import_utils import is_flashpack_available
from .logging import get_logger
logger = get_logger(__name__)
def save_flashpack(
model,
save_directory: str,
variant: Optional[str] = None,
is_main_process: bool = True,
):
"""
Save model weights in FlashPack format along with a metadata config.
Args:
model: Diffusers model instance
save_directory (`str`): Directory to save weights
variant (`str`, *optional*): Model variant
"""
if not is_flashpack_available():
raise ImportError(
"The `use_flashpack=True` argument requires the `flashpack` package. "
"Install it with `pip install flashpack`."
)
from flashpack import pack_to_file
os.makedirs(save_directory, exist_ok=True)
weights_name = _add_variant("model.flashpack", variant)
weights_path = os.path.join(save_directory, weights_name)
config_path = os.path.join(save_directory, "flashpack_config.json")
try:
target_dtype = getattr(model, "dtype", None)
logger.warning(f"Dtype used for FlashPack save: {target_dtype}")
# 1. Save binary weights
pack_to_file(model, weights_path, target_dtype=target_dtype)
# 2. Save config metadata (best-effort)
if hasattr(model, "config"):
try:
if hasattr(model.config, "to_dict"):
config_data = model.config.to_dict()
else:
config_data = dict(model.config)
with open(config_path, "w") as f:
json.dump(config_data, f, indent=4)
except Exception as config_err:
logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}")
except Exception as e:
logger.error(f"Failed to save weights in FlashPack format: {e}")
raise
def load_flashpack(model, flashpack_file: str):
"""
Assign FlashPack weights from a file into an initialized PyTorch model.
"""
if not is_flashpack_available():
raise ImportError("FlashPack weights require the `flashpack` package. Install with `pip install flashpack`.")
from flashpack import assign_from_file
logger.warning(f"Loading FlashPack weights from {flashpack_file}")
try:
assign_from_file(model, flashpack_file)
except Exception as e:
raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e

View File

@@ -231,6 +231,7 @@ _aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_av_available, _av_version = _is_package_available("av")
_flashpack_available, _flashpack_version = _is_package_available("flashpack")
def is_torch_available():
@@ -425,6 +426,10 @@ def is_av_available():
return _av_available
def is_flashpack_available():
return _flashpack_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -942,6 +947,16 @@ def is_aiter_version(operation: str, version: str):
return compare_versions(parse(_aiter_version), operation, version)
@cache
def is_flashpack_version(operation: str, version: str):
"""
Compares the current flashpack version to a given reference with an operation.
"""
if not _flashpack_available:
return False
return compare_versions(parse(_flashpack_version), operation, version)
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects

View File

@@ -37,14 +37,9 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = FluxModularPipeline
pipeline_blocks_class = FluxAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
default_repo_id = "black-forest-labs/FLUX.1-dev"
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None # None if vae_encoder is not supported
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
@@ -68,21 +63,10 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = FluxModularPipeline
pipeline_blocks_class = FluxAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
default_repo_id = "black-forest-labs/FLUX.1-dev"
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(
[
"prompt",
"max_sequence_length",
]
)
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = super().get_pipeline(components_manager, torch_dtype)
@@ -145,13 +129,9 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = FluxKontextModularPipeline
pipeline_blocks_class = FluxKontextAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
default_repo_id = "black-forest-labs/FLUX.1-kontext-dev"
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
decode_block_params = frozenset(["latents"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)

View File

@@ -32,14 +32,9 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2ModularPipeline
pipeline_blocks_class = Flux2AutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
default_repo_id = "black-forest-labs/FLUX.2-dev"
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
@@ -68,10 +63,6 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)

View File

@@ -34,16 +34,10 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
pipeline_class = QwenImageModularPipeline
pipeline_blocks_class = QwenImageAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"
default_repo_id = "Qwen/Qwen-Image"
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "negative_prompt", "max_sequence_length"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None # None if vae_encoder is not supported
def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
@@ -66,16 +60,10 @@ class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGu
pipeline_class = QwenImageEditModularPipeline
pipeline_blocks_class = QwenImageEditAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"
default_repo_id = "Qwen/Qwen-Image-Edit"
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt", "negative_prompt", "max_sequence_length"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image", "height", "width"])
def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
@@ -98,7 +86,6 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
pipeline_class = QwenImageEditPlusModularPipeline
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
default_repo_id = "Qwen/Qwen-Image-Edit-2509"
# No `mask_image` yet.
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])

View File

@@ -279,8 +279,6 @@ class TestSDXLModularPipelineFast(
pipeline_class = StableDiffusionXLModularPipeline
pipeline_blocks_class = StableDiffusionXLAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
default_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
params = frozenset(
[
"prompt",
@@ -293,11 +291,6 @@ class TestSDXLModularPipelineFast(
batch_params = frozenset(["prompt", "negative_prompt"])
expected_image_output_shape = (1, 3, 64, 64)
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = None # None if vae_encoder is not supported
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
@@ -346,11 +339,6 @@ class TestSDXLImg2ImgModularPipelineFast(
batch_params = frozenset(["prompt", "negative_prompt", "image"])
expected_image_output_shape = (1, 3, 64, 64)
# should choose from the dict returned by `get_dummy_inputs`
text_encoder_block_params = frozenset(["prompt"])
decode_block_params = frozenset(["output_type"])
vae_encoder_block_params = frozenset(["image"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {

View File

@@ -48,12 +48,6 @@ class ModularPipelineTesterMixin:
"You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
)
@property
def default_repo_id(self) -> str:
raise NotImplementedError(
"You need to set the attribute `default_repo_id` in the child test class. See existing pipeline tests for reference."
)
@property
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
raise NotImplementedError(
@@ -96,30 +90,6 @@ class ModularPipelineTesterMixin:
"See existing pipeline tests for reference."
)
def text_encoder_block_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `text_encoder_block_params` in the child test class. "
"`text_encoder_block_params` are the parameters required to be passed to the text encoder block. "
" if should be a subset of the parameters returned by `get_dummy_inputs`"
"See existing pipeline tests for reference."
)
def decode_block_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `decode_block_params` in the child test class. "
"`decode_block_params` are the parameters required to be passed to the decode block. "
" if should be a subset of the parameters returned by `get_dummy_inputs`"
"See existing pipeline tests for reference."
)
def vae_encoder_block_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `vae_encoder_block_params` in the child test class. "
"`vae_encoder_block_params` are the parameters required to be passed to the vae encoder block. "
" if should be a subset of the parameters returned by `get_dummy_inputs`"
"See existing pipeline tests for reference."
)
def setup_method(self):
# clean up the VRAM before each test
torch.compiler.reset()
@@ -154,96 +124,6 @@ class ModularPipelineTesterMixin:
_check_for_parameters(self.params, input_parameters, "input")
_check_for_parameters(self.optional_params, optional_parameters, "optional")
def test_loading_from_default_repo(self):
if self.default_repo_id is None:
return
try:
pipe = ModularPipeline.from_pretrained(self.default_repo_id)
assert pipe.blocks.__class__ == self.pipeline_blocks_class
except Exception as e:
assert False, f"Failed to load pipeline from default repo: {e}"
def test_modular_inference(self):
# run the pipeline to get the base output for comparison
pipe = self.get_pipeline()
pipe.to(torch_device, torch.float32)
inputs = self.get_dummy_inputs()
standard_output = pipe(**inputs, output="images")
# create text, denoise, decoder (and optional vae encoder) nodes
blocks = self.pipeline_blocks_class()
assert "text_encoder" in blocks.sub_blocks, "`text_encoder` block is not present in the pipeline"
assert "denoise" in blocks.sub_blocks, "`denoise` block is not present in the pipeline"
assert "decode" in blocks.sub_blocks, "`decode` block is not present in the pipeline"
if self.vae_encoder_block_params is not None:
assert "vae_encoder" in blocks.sub_blocks, "`vae_encoder` block is not present in the pipeline"
# manually set the components in the sub_pipe
# a hack to workaround the fact the default pipeline properties are often incorrect for testing cases,
# #e.g. vae_scale_factor is ususally not 8 because vae is configured to be smaller for testing
def manually_set_all_components(pipe: ModularPipeline, sub_pipe: ModularPipeline):
for n, comp in pipe.components.items():
if not hasattr(sub_pipe, n):
setattr(sub_pipe, n, comp)
text_node = blocks.sub_blocks["text_encoder"].init_pipeline(self.pretrained_model_name_or_path)
text_node.load_components(torch_dtype=torch.float32)
text_node.to(torch_device)
manually_set_all_components(pipe, text_node)
denoise_node = blocks.sub_blocks["denoise"].init_pipeline(self.pretrained_model_name_or_path)
denoise_node.load_components(torch_dtype=torch.float32)
denoise_node.to(torch_device)
manually_set_all_components(pipe, denoise_node)
decoder_node = blocks.sub_blocks["decode"].init_pipeline(self.pretrained_model_name_or_path)
decoder_node.load_components(torch_dtype=torch.float32)
decoder_node.to(torch_device)
manually_set_all_components(pipe, decoder_node)
if self.vae_encoder_block_params is not None:
vae_encoder_node = blocks.sub_blocks["vae_encoder"].init_pipeline(self.pretrained_model_name_or_path)
vae_encoder_node.load_components(torch_dtype=torch.float32)
vae_encoder_node.to(torch_device)
manually_set_all_components(pipe, vae_encoder_node)
else:
vae_encoder_node = None
# prepare inputs for each node
inputs = self.get_dummy_inputs()
def get_block_inputs(inputs: dict, block_params: frozenset) -> tuple[dict, dict]:
block_inputs = {}
for name in block_params:
if name in inputs:
block_inputs[name] = inputs.pop(name)
return block_inputs, inputs
text_inputs, inputs = get_block_inputs(inputs, self.text_encoder_block_params)
decoder_inputs, inputs = get_block_inputs(inputs, self.decode_block_params)
if vae_encoder_node is not None:
vae_encoder_inputs, inputs = get_block_inputs(inputs, self.vae_encoder_block_params)
# this is also to make sure pipelines mark text outputs as denoiser_input_fields
text_output = text_node(**text_inputs).get_by_kwargs("denoiser_input_fields")
if vae_encoder_node is not None:
vae_encoder_output = vae_encoder_node(**vae_encoder_inputs).values
denoise_inputs = {**text_output, **vae_encoder_output, **inputs}
else:
denoise_inputs = {**text_output, **inputs}
# denoise node output should be "latents"
latents = denoise_node(**denoise_inputs).latents
# denoder node input should be "latents" and output should be "images"
modular_output = decoder_node(**decoder_inputs, latents=latents).images
assert modular_output.shape == standard_output.shape, (
f"Modular output should have same shape as standard output {standard_output.shape}, but got {modular_output.shape}"
)
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
pipe = self.get_pipeline().to(torch_device)