[AudioLDM] Generalise conversion script (#3328)
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -774,6 +774,8 @@ def load_pipeline_from_original_audioldm_ckpt(
     extract_ema: bool = False,
     scheduler_type: str = "ddim",
     num_in_channels: int = None,
+    model_channels: int = None,
+    num_head_channels: int = None,
     device: str = None,
     from_safetensors: bool = False,
 ) -> AudioLDMPipeline:
@@ -784,23 +786,36 @@ def load_pipeline_from_original_audioldm_ckpt(
     global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
     recommended that you override the default values and/or supply an `original_config_file` wherever possible.
 
-    :param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file
-        corresponding to the original architecture.
-        If `None`, will be automatically instantiated based on default values.
-    :param image_size: The image size that the model was trained on. Use 512 for original AudioLDM checkpoints. :param
-        prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for original
-        AudioLDM checkpoints.
-    :param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
-        inferred.
-    :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
-        "euler-ancestral", "dpm", "ddim"]`.
-    :param extract_ema: Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract
-        the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually
-        yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
-    :param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If
-        `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors
-        instead of PyTorch.
-    :return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
+    Args:
+        checkpoint_path (`str`): Path to `.ckpt` file.
+        original_config_file (`str`):
+            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
+            set to the audioldm-s-full-v2 config.
+        image_size (`int`, *optional*, defaults to 512):
+            The image size that the model was trained on.
+        prediction_type (`str`, *optional*):
+            The prediction type that the model was trained on. If `None`, will be automatically
+            inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
+        num_in_channels (`int`, *optional*, defaults to `None`):
+            The number of UNet input channels. If `None`, it will be automatically inferred from the config.
+        model_channels (`int`, *optional*, defaults to `None`):
+            The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override
+            to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.
+        num_head_channels (`int`, *optional*, defaults to `None`):
+            The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override
+            to 32 for the small and medium checkpoints, and 64 for the large.
+        scheduler_type (`str`, *optional*, defaults to `'ddim'`):
+            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
+            "ddim"]`.
+        extract_ema (`bool`, *optional*, defaults to `False`):
+            Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
+            or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
+            quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
+        device (`str`, *optional*, defaults to `None`):
+            The device to use. Pass `None` to determine automatically.
+        from_safetensors (`bool`, *optional*, defaults to `False`):
+            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
+    Returns: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
 
     if not is_omegaconf_available():
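Taken together, the new arguments let a single entry point convert the small, medium and large AudioLDM variants. A minimal sketch of a Python-side call for a medium checkpoint, using the override values the docstring suggests (the checkpoint path is a placeholder, and saving via `save_pretrained` is the standard `DiffusionPipeline` method, not part of this diff):

```python
# Sketch: convert a medium AudioLDM checkpoint with the new overrides.
# Per the docstring, medium checkpoints use model_channels=192 and
# num_head_channels=32. The checkpoint path is a placeholder.
pipe = load_pipeline_from_original_audioldm_ckpt(
    checkpoint_path="path/to/audioldm-m-full.ckpt",
    model_channels=192,
    num_head_channels=32,
    scheduler_type="ddim",
)
pipe.save_pretrained("./audioldm-m-full")
```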
@@ -837,6 +852,12 @@ def load_pipeline_from_original_audioldm_ckpt(
     if num_in_channels is not None:
         original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
 
+    if model_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels
+
+    if num_head_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels
+
     if (
         "parameterization" in original_config["model"]["params"]
         and original_config["model"]["params"]["parameterization"] == "v"
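The three overrides above all target the same nested section of the original OmegaConf config. A minimal sketch of the structure they assume, inferred only from the keys indexed in the code (the values shown are placeholders, using the small-checkpoint numbers from the docstring):

```python
# Structure assumed by the overrides above (inferred from the indexed keys);
# the values are placeholders using the small-checkpoint numbers.
original_config = {
    "model": {
        "params": {
            "parameterization": "eps",  # "v" triggers the v-prediction branch below
            "unet_config": {
                "params": {
                    "in_channels": 8,         # set via num_in_channels
                    "model_channels": 128,    # set via model_channels
                    "num_head_channels": 32,  # set via num_head_channels
                },
            },
        },
    },
}
```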
@@ -960,6 +981,20 @@ if __name__ == "__main__":
         type=int,
         help="The number of input channels. If `None` number of input channels will be automatically inferred.",
     )
+    parser.add_argument(
+        "--model_channels",
+        default=None,
+        type=int,
+        help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
+        " to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
+    )
+    parser.add_argument(
+        "--num_head_channels",
+        default=None,
+        type=int,
+        help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
+        " to 32 for the small and medium checkpoints, and 64 for the large.",
+    )
     parser.add_argument(
         "--scheduler_type",
         default="ddim",
@@ -1009,6 +1044,8 @@ if __name__ == "__main__":
         extract_ema=args.extract_ema,
         scheduler_type=args.scheduler_type,
         num_in_channels=args.num_in_channels,
+        model_channels=args.model_channels,
+        num_head_channels=args.num_head_channels,
         from_safetensors=args.from_safetensors,
         device=args.device,
     )
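With the flags above wired through to the function call, the same conversion is available from the command line. A sketch, assuming the script lives at `scripts/convert_original_audioldm_to_diffusers.py` and exposes the usual `--checkpoint_path` and `--dump_path` flags (neither appears in these hunks; only `--model_channels`, `--num_head_channels` and `--scheduler_type` are shown above):

```bash
# Sketch of a CLI conversion for a medium checkpoint; the script path and the
# --checkpoint_path/--dump_path flags are assumptions not shown in this diff.
python scripts/convert_original_audioldm_to_diffusers.py \
  --checkpoint_path path/to/audioldm-m-full.ckpt \
  --model_channels 192 \
  --num_head_channels 32 \
  --scheduler_type ddim \
  --dump_path ./audioldm-m-full
```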