Compare commits

...

3 Commits

Author SHA1 Message Date
Dhruv Nair
b635fb9e1c Merge branch 'main' into animatediff-safetnsors-support 2024-04-29 17:21:23 +05:30
Dhruv Nair
f20ff26950 update 2024-04-29 11:50:54 +00:00
Dhruv Nair
03f101e906 update 2024-04-24 07:42:12 +00:00
2 changed files with 11 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
import argparse
import torch
from safetensors.torch import save_file
from safetensors.torch import load_file, save_file
def convert_motion_module(original_state_dict):
@@ -34,7 +34,10 @@ def get_args():
if __name__ == "__main__":
args = get_args()
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if args.ckpt_path.endswith(".safetensors"):
state_dict = load_file(args.ckpt_path)
else:
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict.keys():
state_dict = state_dict["state_dict"]

View File

@@ -1,6 +1,7 @@
import argparse
import torch
from safetensors.torch import load_file
from diffusers import MotionAdapter
@@ -38,7 +39,11 @@ def get_args():
if __name__ == "__main__":
args = get_args()
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if args.ckpt_path.endswith(".safetensors"):
state_dict = load_file(args.ckpt_path)
else:
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict.keys():
state_dict = state_dict["state_dict"]