Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-24 04:25:56 +08:00)

Compare commits: 2 commits (devanshi00...modular-up)

| Author | SHA1 | Date |
|---|---|---|
| | 8d42a97a40 | |
| | 39a6a0c171 | |
@@ -675,7 +675,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "10GB",
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -708,9 +707,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a
binary format that allows for faster loading. Requires the `flashpack` library to be installed.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
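The signature and docstring hunks above are the user-facing surface of the `use_flashpack` flag that appears on one side of this compare. A minimal usage sketch, assuming a diffusers checkout that includes the flashpack-aware `save_pretrained` shown here and that `pip install flashpack` has been run; the model configuration below is illustrative:

```python
from diffusers import UNet2DModel

# A tiny randomly initialized model keeps the example fast; any ModelMixin subclass works.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

# With use_flashpack=True the weights are written as model.flashpack (plus flashpack_config.json)
# instead of sharded safetensors files.
model.save_pretrained("./unet-flashpack", use_flashpack=True)
```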
@@ -731,6 +727,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)

weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)

os.makedirs(save_directory, exist_ok=True)

if push_to_hub:
@@ -744,80 +746,67 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Only save the model itself if we are using distributed training
model_to_save = self

# Attach architecture to the config
# Save the config
if is_main_process:
model_to_save.save_config(save_directory)

if use_flashpack:
if not is_main_process:
return
# Save the model
state_dict = model_to_save.state_dict()

from ..utils.flashpack_utils import save_flashpack
# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)

save_flashpack(model_to_save, save_directory, variant=variant)
else:
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that the file to be deleted matches the format of a sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)

state_dict = model_to_save.state_dict()
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)

# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that the file to be deleted matches the format of a sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)

# Save each shard
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)

# Save index file if sharded
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
f"index located at {save_index_file}."
)
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
torch.save(shard, filepath)

if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
f"index located at {save_index_file}."
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")

# Push to hub if requested (common to both paths)
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token)
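The non-flashpack branch above relies on `split_torch_state_dict_into_shards` from `huggingface_hub` to decide the shard layout, then writes each shard plus an index file. A standalone sketch of that same sharding pattern using the public `huggingface_hub` API; the directory, tensor names, and shard size are illustrative:

```python
import json
import os

import safetensors.torch
import torch
from huggingface_hub import split_torch_state_dict_into_shards

state_dict = {f"layer_{i}.weight": torch.randn(256, 256) for i in range(8)}
save_directory = "./sharded-demo"
os.makedirs(save_directory, exist_ok=True)

# A tiny max_shard_size forces several shards so the index logic is exercised.
split = split_torch_state_dict_into_shards(
    state_dict, filename_pattern="model{suffix}.safetensors", max_shard_size="1MB"
)

for filename, tensors in split.filename_to_tensors.items():
    shard = {name: state_dict[name].contiguous() for name in tensors}
    safetensors.torch.save_file(shard, os.path.join(save_directory, filename), metadata={"format": "pt"})

if split.is_sharded:
    index = {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
    with open(os.path.join(save_directory, "model.safetensors.index.json"), "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2, sort_keys=True)
```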
@@ -950,10 +939,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack)
weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack`
file cannot be used, loading automatically falls back to the standard path (for example, `safetensors`).
disable_mmap (`bool`, *optional*, defaults to `False`):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -997,7 +982,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
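A hedged loading sketch for the `use_flashpack` path documented above, assuming the flashpack-aware `from_pretrained` from this diff; the repository id is illustrative and any hosted diffusers model works:

```python
from diffusers import UNet2DConditionModel

# With the flashpack branch installed, a repo that ships a model.flashpack file is loaded
# from it directly; otherwise the call falls back to the usual safetensors path.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1",  # illustrative repo id
    subfolder="unet",
    use_flashpack=True,
)
```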
@@ -1215,31 +1199,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)

flashpack_file = None
if use_flashpack:
try:
flashpack_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant("model.flashpack", variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
except EnvironmentError:
flashpack_file = None
logger.warning(
"`use_flashpack` was set to True but no flashpack file was found. Falling back to non-flashpack alternatives."
)

if flashpack_file is None:
else:
# in case it is sharded, we already have the index
if is_sharded:
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
@@ -1255,7 +1215,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dduf_entries=dduf_entries,
)
elif use_safetensors:
logger.warning("Trying to load model weights in safetensors format.")
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
@@ -1321,29 +1280,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

if flashpack_file is not None:
from ..utils.flashpack_utils import load_flashpack

# Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
# the model with meta tensors. Since FlashPack cannot write into meta tensors, we
# explicitly materialize parameters before loading to ensure correctness and parity
# with the standard loading path.
if any(p.device.type == "meta" for p in model.parameters()):
model.to_empty(device="cpu")
load_flashpack(model, flashpack_file)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
model.eval()

if output_loading_info:
return model, {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}

return model

state_dict = None
if not is_sharded:
# Time to load the checkpoint
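The comment in the hunk above points at a real constraint: tensors created on the `meta` device have no storage, so nothing can be loaded into them until they are materialized. A small self-contained sketch of the same pattern using plain PyTorch; the module here stands in for a diffusers model:

```python
import torch
import torch.nn as nn

# Build the module on the meta device: parameters have shapes but no storage.
with torch.device("meta"):
    module = nn.Linear(16, 4)

print(any(p.device.type == "meta" for p in module.parameters()))  # True

# Materialize parameters with uninitialized CPU memory, then fill them from a real state dict.
module.to_empty(device="cpu")
module.load_state_dict({"weight": torch.zeros(4, 16), "bias": torch.zeros(4)})
```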
@@ -1391,6 +1327,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
is_parallel_loading_enabled=is_parallel_loading_enabled,
disable_mmap=disable_mmap,
)
loading_info = {
"missing_keys": missing_keys,
@@ -1436,8 +1373,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if output_loading_info:
return model, loading_info

logger.warning(f"Model {pretrained_model_name_or_path} loaded successfully")

return model

# Adapted from `transformers`.
@@ -1552,11 +1552,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")

self.blocks = blocks
self._blocks = blocks
self._components_manager = components_manager
self._collection = collection
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs}

# update component_specs and config_specs based on modular_model_index.json
if modular_config_dict is not None:
@@ -1603,7 +1603,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for name, config_spec in self._config_specs.items():
default_configs[name] = config_spec.default
self.register_to_config(**default_configs)
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
self.register_to_config(
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
)

@property
def default_call_parameters(self) -> Dict[str, Any]:
@@ -1612,7 +1614,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- Dictionary mapping input names to their default values
"""
params = {}
for input_param in self.blocks.inputs:
for input_param in self._blocks.inputs:
params[input_param.name] = input_param.default
return params

@@ -1775,7 +1777,15 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Returns:
- The docstring of the pipeline blocks
"""
return self.blocks.doc
return self._blocks.doc

@property
def blocks(self) -> ModularPipelineBlocks:
"""
Returns:
- A copy of the pipeline blocks
"""
return deepcopy(self._blocks)

def register_components(self, **kwargs):
"""
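The hunks above move the pipeline's blocks behind a private `_blocks` attribute and expose a read-only `blocks` property that hands out a deep copy, so callers cannot mutate the pipeline's internal block graph. A minimal sketch of that encapsulation pattern; the class and names here are illustrative, not the diffusers API:

```python
from copy import deepcopy


class Pipeline:
    def __init__(self, blocks):
        # The mutable source of truth stays private.
        self._blocks = blocks

    @property
    def blocks(self):
        # Callers get an independent copy; edits to it never touch self._blocks.
        return deepcopy(self._blocks)


pipe = Pipeline(blocks={"denoise": "unet"})
view = pipe.blocks
view["denoise"] = "other"
print(pipe.blocks)  # {'denoise': 'unet'} -- internal state unchanged
```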
@@ -2509,7 +2519,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)

def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
for sub_block_name, sub_block in self._blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)

@@ -2563,7 +2573,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):

# Add inputs to state, using defaults if not provided in the kwargs or the state
# if the same input is already in the state, it is overridden when provided in the kwargs
for expected_input_param in self.blocks.inputs:
for expected_input_param in self._blocks.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
@@ -2582,9 +2592,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Run the pipeline
with torch.no_grad():
try:
_, state = self.blocks(self, state)
_, state = self._blocks(self, state)
except Exception:
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n"
logger.error(error_msg)
raise
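The `__call__` hunk above fills the pipeline state from the block's expected inputs, preferring caller kwargs over existing state values and falling back to declared defaults. A tiny standalone sketch of that precedence rule; the names and defaults are illustrative:

```python
# Each expected input carries a name and a default, like the block input params above.
expected_inputs = [("num_inference_steps", 50), ("guidance_scale", 7.5), ("height", 512)]


def fill_state(state, **kwargs):
    for name, default in expected_inputs:
        if name in kwargs:        # a caller-provided value wins
            state[name] = kwargs[name]
        elif name not in state:   # otherwise keep the existing state value, else use the default
            state[name] = default
    return state


state = fill_state({"height": 768}, guidance_scale=5.0)
print(state)  # {'height': 768, 'num_inference_steps': 50, 'guidance_scale': 5.0}
```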
@@ -756,7 +756,6 @@ def load_sub_model(
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
use_flashpack: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
disable_mmap: bool,
@@ -839,9 +838,6 @@ def load_sub_model(
loading_kwargs["variant"] = model_variants.pop(name, None)
loading_kwargs["use_safetensors"] = use_safetensors

if is_diffusers_model:
loading_kwargs["use_flashpack"] = use_flashpack

if from_flax:
loading_kwargs["from_flax"] = True
@@ -243,7 +243,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant: Optional[str] = None,
max_shard_size: Optional[Union[int, str]] = None,
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -269,9 +268,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip
install flashpack`.

kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -343,7 +340,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters

save_kwargs = {}
if save_method_accept_safe:
@@ -353,8 +349,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if save_method_accept_max_shard_size and max_shard_size is not None:
# max_shard_size is expected to not be None in ModelMixin
save_kwargs["max_shard_size"] = max_shard_size
if save_method_accept_flashpack:
save_kwargs["use_flashpack"] = use_flashpack

save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
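`DiffusionPipeline.save_pretrained` only forwards an option to a component's own save method when that method actually declares the parameter, which it detects via `inspect.signature`. A minimal sketch of that feature-detection pattern; the helper name is illustrative:

```python
import inspect


def save_component(save_method, directory, use_flashpack=False, safe_serialization=True):
    params = inspect.signature(save_method).parameters

    save_kwargs = {}
    if "safe_serialization" in params:
        save_kwargs["safe_serialization"] = safe_serialization
    if "use_flashpack" in params:
        # Only forwarded when the component's save method knows about the flag.
        save_kwargs["use_flashpack"] = use_flashpack

    save_method(directory, **save_kwargs)


# e.g. save_component(model.save_pretrained, "./out") for a diffusers component,
# or a tokenizer's save_pretrained, which simply never receives use_flashpack.
```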
@@ -713,11 +707,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
is found. If flashpack is unavailable or the `.flashpack` file cannot be used, loading automatically
falls back to the standard path (for example, `safetensors`). Requires the `flashpack` library: `pip
install flashpack`.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
@@ -783,7 +772,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant = kwargs.pop("variant", None)
dduf_file = kwargs.pop("dduf_file", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
@@ -1073,7 +1061,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
use_safetensors=use_safetensors,
use_flashpack=use_flashpack,
dduf_entries=dduf_entries,
provider_options=provider_options,
disable_mmap=disable_mmap,
@@ -226,7 +226,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
time_shift_type: Literal["exponential"] = "exponential",
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
shift_terminal: Optional[float] = None,
) -> None:
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -246,8 +245,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
if shift_terminal is not None and not use_flow_sigmas:
raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.")

if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
@@ -316,12 +313,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index

def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
):
self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

@@ -330,24 +323,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
sigmas (`List[float]`, *optional*):
Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
automatically.
mu (`float`, *optional*):
Optional mu parameter for dynamic shifting when using exponential time shift type.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to `True`")

if sigmas is not None:
if not self.config.use_flow_sigmas:
raise ValueError(
"Passing `sigmas` is only supported when `use_flow_sigmas=True`. "
"Please set `use_flow_sigmas=True` during scheduler initialization."
)
num_inference_steps = len(sigmas)

# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if mu is not None:
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
self.config.flow_shift = np.exp(mu)
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
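The `set_timesteps` hunks above shift flow-matching sigmas either by a static `flow_shift` or, when dynamic shifting is enabled, by `exp(mu)`. A small numeric sketch of the shift formula that appears later in this diff, `shifted = shift * s / (1 + (shift - 1) * s)`; the values below are illustrative:

```python
import math

import numpy as np

sigmas = np.linspace(1.0, 1.0 / 1000, 6)[:-1]  # toy flow sigmas in (0, 1]

# Static shift (e.g. flow_shift=3.0, a common choice for flow-matching models).
shift = 3.0
static = shift * sigmas / (1 + (shift - 1) * sigmas)

# Dynamic shifting replaces the static value with exp(mu), matching `flow_shift = np.exp(mu)` above.
mu = 1.0
dynamic_shift = math.exp(mu)
dynamic = dynamic_shift * sigmas / (1 + (dynamic_shift - 1) * sigmas)

print(static.round(3), dynamic.round(3))
```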
@@ -372,9 +354,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -394,8 +375,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -410,8 +389,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_beta_sigmas:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -426,18 +403,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_flow_sigmas:
if sigmas is None:
sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
eps = 1e-6
if np.fabs(sigmas[0] - 1) < eps:
# to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
sigmas[0] -= eps
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy()
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
@@ -449,8 +417,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
else:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -480,43 +446,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = None
self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":
return self._time_shift_linear(mu, sigma, t)

# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
r"""
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` value.

Reference:
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51

Args:
t (`torch.Tensor`):
A tensor of timesteps to be stretched and shifted.

Returns:
`torch.Tensor`:
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
"""
one_minus_z = 1 - t
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
stretched_t = 1 - (one_minus_z / scale_factor)
return stretched_t

# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
def _time_shift_exponential(self, mu, sigma, t):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
def _time_shift_linear(self, mu, sigma, t):
return mu / (mu + (1 / t - 1) ** sigma)

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
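`stretch_shift_to_terminal` rescales the schedule so that its last value lands exactly on `shift_terminal`, and the algebra is easy to check in isolation. A quick sketch verifying that property with plain tensors, using the same arithmetic as the method above with `shift_terminal` passed explicitly:

```python
import torch


def stretch_shift_to_terminal(t: torch.Tensor, shift_terminal: float) -> torch.Tensor:
    one_minus_z = 1 - t
    scale_factor = one_minus_z[-1] / (1 - shift_terminal)
    return 1 - one_minus_z / scale_factor


t = torch.linspace(0.95, 0.05, steps=10)  # a toy decreasing schedule
stretched = stretch_shift_to_terminal(t, shift_terminal=0.1)
print(stretched[-1])  # tensor(0.1000) -- the schedule now ends exactly at shift_terminal
```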
@@ -1,81 +0,0 @@
import json
import os
from typing import Optional

from ..utils import _add_variant
from .import_utils import is_flashpack_available
from .logging import get_logger


logger = get_logger(__name__)


def save_flashpack(
    model,
    save_directory: str,
    variant: Optional[str] = None,
    is_main_process: bool = True,
):
    """
    Save model weights in FlashPack format along with a metadata config.

    Args:
        model: Diffusers model instance
        save_directory (`str`): Directory to save weights
        variant (`str`, *optional*): Model variant
        is_main_process (`bool`, *optional*, defaults to `True`): Whether the calling process is the main one
    """
    if not is_flashpack_available():
        raise ImportError(
            "The `use_flashpack=True` argument requires the `flashpack` package. "
            "Install it with `pip install flashpack`."
        )

    from flashpack import pack_to_file

    os.makedirs(save_directory, exist_ok=True)

    weights_name = _add_variant("model.flashpack", variant)
    weights_path = os.path.join(save_directory, weights_name)
    config_path = os.path.join(save_directory, "flashpack_config.json")

    try:
        target_dtype = getattr(model, "dtype", None)
        logger.warning(f"Dtype used for FlashPack save: {target_dtype}")

        # 1. Save binary weights
        pack_to_file(model, weights_path, target_dtype=target_dtype)

        # 2. Save config metadata (best-effort)
        if hasattr(model, "config"):
            try:
                if hasattr(model.config, "to_dict"):
                    config_data = model.config.to_dict()
                else:
                    config_data = dict(model.config)

                with open(config_path, "w") as f:
                    json.dump(config_data, f, indent=4)

            except Exception as config_err:
                logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}")

    except Exception as e:
        logger.error(f"Failed to save weights in FlashPack format: {e}")
        raise


def load_flashpack(model, flashpack_file: str):
    """
    Assign FlashPack weights from a file into an initialized PyTorch model.
    """
    if not is_flashpack_available():
        raise ImportError("FlashPack weights require the `flashpack` package. Install with `pip install flashpack`.")

    from flashpack import assign_from_file

    logger.warning(f"Loading FlashPack weights from {flashpack_file}")

    try:
        assign_from_file(model, flashpack_file)
    except Exception as e:
        raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e
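These helpers (the whole file is removed by this compare) wrap the `flashpack` package's `pack_to_file` and `assign_from_file`. A hedged round-trip sketch, assuming a diffusers tree that still contains `flashpack_utils` and that the `flashpack` package is installed:

```python
from diffusers import UNet2DModel
from diffusers.utils.flashpack_utils import load_flashpack, save_flashpack  # flashpack branch only

model = UNet2DModel()  # any ModelMixin subclass works

# Writes model.flashpack (plus flashpack_config.json) into the directory.
save_flashpack(model, "./unet-flashpack")

# Re-assign the packed weights into a freshly constructed model of the same architecture.
restored = UNet2DModel.from_config(model.config)
load_flashpack(restored, "./unet-flashpack/model.flashpack")
```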
@@ -231,7 +231,6 @@ _aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_av_available, _av_version = _is_package_available("av")
_flashpack_available, _flashpack_version = _is_package_available("flashpack")


def is_torch_available():
@@ -426,10 +425,6 @@ def is_av_available():
return _av_available


def is_flashpack_available():
return _flashpack_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
@@ -947,16 +942,6 @@ def is_aiter_version(operation: str, version: str):
return compare_versions(parse(_aiter_version), operation, version)


@cache
def is_flashpack_version(operation: str, version: str):
"""
Compares the current flashpack version to a given reference with an operation.
"""
if not _flashpack_available:
return False
return compare_versions(parse(_flashpack_version), operation, version)


def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects
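`is_flashpack_available` and `is_flashpack_version` follow the same soft-dependency pattern as the other `_is_package_available` helpers, so callers can guard optional code paths without importing flashpack eagerly. A short hedged sketch; the `0.1.0` floor is illustrative, not a real requirement:

```python
from diffusers.utils.import_utils import is_flashpack_available, is_flashpack_version  # flashpack branch only


def maybe_use_flashpack() -> bool:
    # Returns False both when flashpack is missing and when it is too old,
    # so the caller can silently fall back to safetensors.
    return is_flashpack_available() and is_flashpack_version(">=", "0.1.0")


use_flashpack = maybe_use_flashpack()
```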