Compare commits

...

1 Commits

Author SHA1 Message Date
Pedro Cuenca
f317695f6b mps: fix torch.where in svd codepath. 2023-12-18 18:58:09 +01:00

View File

@@ -1334,7 +1334,7 @@ class AlphaBlender(nn.Module):
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
torch.ones(1, 1, device=image_only_indicator.device, dtype=self.mix_factor.dtype),
torch.sigmoid(self.mix_factor)[..., None],
)