mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
Use model_info.id instead of model_info.modelId (#8912)
Mention model_info.id instead of model_info.modelId
This commit is contained in:
@@ -103,12 +103,12 @@ results["google_ddpm_ema_cat_256"] = torch.tensor([
|
||||
|
||||
models = api.list_models(filter="diffusers")
|
||||
for mod in models:
|
||||
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
|
||||
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
|
||||
if "google" in mod.author or mod.id == "CompVis/ldm-celebahq-256":
|
||||
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.id.split("/")[-1]
|
||||
|
||||
print(f"Started running {mod.modelId}!!!")
|
||||
print(f"Started running {mod.id}!!!")
|
||||
|
||||
if mod.modelId.startswith("CompVis"):
|
||||
if mod.id.startswith("CompVis"):
|
||||
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
|
||||
else:
|
||||
model = UNet2DModel.from_pretrained(local_checkpoint)
|
||||
@@ -122,6 +122,6 @@ for mod in models:
|
||||
logits = model(noise, time_step).sample
|
||||
|
||||
assert torch.allclose(
|
||||
logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
|
||||
logits[0, 0, 0, :30], results["_".join("_".join(mod.id.split("/")).split("-"))], atol=1e-3
|
||||
)
|
||||
print(f"{mod.modelId} has passed successfully!!!")
|
||||
print(f"{mod.id} has passed successfully!!!")
|
||||
|
||||
Reference in New Issue
Block a user