Compare commits

...

2 Commits

Author SHA1 Message Date
Dhruv Nair
53bc88e1d8 make style 2024-02-26 03:49:17 +00:00
Dhruv Nair
8cf57b3638 update 2024-02-23 11:58:15 +00:00

View File

@@ -4,6 +4,7 @@ import math
import os
from copy import deepcopy
import requests
import torch
from audio_diffusion.models import DiffusionAttnUnet1D
from diffusion import sampling
@@ -73,9 +74,14 @@ class DiffusionUncond(nn.Module):
def download(model_name):
url = MODELS_MAP[model_name]["url"]
os.system(f"wget {url} ./")
r = requests.get(url, stream=True)
return f"./{model_name}.ckpt"
local_filename = f"./{model_name}.ckpt"
with open(local_filename, "wb") as fp:
for chunk in r.iter_content(chunk_size=8192):
fp.write(chunk)
return local_filename
DOWN_NUM_TO_LAYER = {