Use model_info.id instead of model_info.modelId (#8912)

Mention model_info.id instead of model_info.modelId
This commit is contained in:
Lucain
2024-07-20 16:31:21 +02:00
committed by GitHub
parent fe7948941d
commit 56e772ab7e

View File

@@ -103,12 +103,12 @@ results["google_ddpm_ema_cat_256"] = torch.tensor([
models = api.list_models(filter="diffusers")
for mod in models:
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
if "google" in mod.author or mod.id == "CompVis/ldm-celebahq-256":
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.id.split("/")[-1]
print(f"Started running {mod.modelId}!!!")
print(f"Started running {mod.id}!!!")
if mod.modelId.startswith("CompVis"):
if mod.id.startswith("CompVis"):
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
else:
model = UNet2DModel.from_pretrained(local_checkpoint)
@@ -122,6 +122,6 @@ for mod in models:
logits = model(noise, time_step).sample
assert torch.allclose(
logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
logits[0, 0, 0, :30], results["_".join("_".join(mod.id.split("/")).split("-"))], atol=1e-3
)
print(f"{mod.modelId} has passed successfully!!!")
print(f"{mod.id} has passed successfully!!!")