mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
* [WIP] Remove code snippets containing `is_safetensors_available()` * Modifying `import_utils.py` * update pipeline tests for safetensor default * fix test related to cached requests * address import nits --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
72 lines
2.3 KiB
Python
72 lines
2.3 KiB
Python
import argparse
|
|
|
|
import safetensors.torch
|
|
|
|
from diffusers import AutoencoderTiny
|
|
|
|
|
|
"""
|
|
Example - From the diffusers root directory:
|
|
|
|
Download the weights:
|
|
```sh
|
|
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors
|
|
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors
|
|
```
|
|
|
|
Convert the model:
|
|
```sh
|
|
$ python scripts/convert_tiny_autoencoder_to_diffusers.py \
|
|
--encoder_ckpt_path taesd_encoder.safetensors \
|
|
--decoder_ckpt_path taesd_decoder.safetensors \
|
|
--dump_path taesd-diffusers
|
|
```
|
|
"""
|
|
|
|
def convert_state_dicts(encoder_state_dict, decoder_state_dict):
    """Remap original encoder/decoder checkpoint keys to the diffusers layout.

    Encoder keys are kept as-is and prefixed with ``encoder.layers.``.
    Decoder keys have their leading integer layer index lowered by one and
    are then prefixed with ``decoder.layers.``.

    Args:
        encoder_state_dict: Mapping of original encoder parameter names to
            their values.
        decoder_state_dict: Mapping of original decoder parameter names to
            their values. Each name must start with an integer layer index
            followed by ``.``.

    Returns:
        A new dict holding the renamed encoder entries followed by the
        renamed decoder entries. Values are passed through untouched (no
        copies are made).

    Raises:
        ValueError: If the text before the first ``.`` of a decoder key is
            not an integer.
    """
    new_state_dict = {f"encoder.layers.{key}": value for key, value in encoder_state_dict.items()}

    for key, value in decoder_state_dict.items():
        # Split once at the first dot: "<layer index>.<rest of the name>".
        layer_index, _, remainder = key.partition(".")
        # NOTE(review): the index is shifted down by one — presumably the
        # original decoder has one leading layer that the diffusers module
        # does not index; confirm against the gist referenced in main().
        new_state_dict[f"decoder.layers.{int(layer_index) - 1}.{remainder}"] = value

    return new_state_dict


def main():
    """Convert original encoder/decoder safetensors checkpoints to diffusers.

    Parses the command line, loads both checkpoints, renames their keys with
    ``convert_state_dicts``, loads the result into a freshly constructed
    ``AutoencoderTiny`` and saves it to ``--dump_path``. The value of
    ``--use_safetensors`` is forwarded as ``safe_serialization``, so the flag
    must be passed explicitly to serialize in the safetensors format.
    """
    parser = argparse.ArgumentParser()

    # All three paths are required, so no defaults are declared for them.
    parser.add_argument("--dump_path", type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--encoder_ckpt_path",
        type=str,
        required=True,
        help="Path to the encoder ckpt.",
    )
    parser.add_argument(
        "--decoder_ckpt_path",
        type=str,
        required=True,
        help="Path to the decoder ckpt.",
    )
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Whether to serialize in the safetensors format."
    )
    args = parser.parse_args()

    print("Loading the original state_dicts of the encoder and the decoder...")
    encoder_state_dict = safetensors.torch.load_file(args.encoder_ckpt_path)
    decoder_state_dict = safetensors.torch.load_file(args.decoder_ckpt_path)

    print("Populating the state_dicts in the diffusers format...")
    tiny_autoencoder = AutoencoderTiny()
    new_state_dict = convert_state_dicts(encoder_state_dict, decoder_state_dict)

    # Assertion tests with the original implementation can be found here:
    # https://gist.github.com/sayakpaul/337b0988f08bd2cf2b248206f760e28f
    tiny_autoencoder.load_state_dict(new_state_dict)
    print("Population successful, serializing...")
    tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors)


if __name__ == "__main__":
    main()
|