Compare commits

...

3 Commits

Author SHA1 Message Date
DN6
1ca9acc269 update 2024-03-07 14:08:39 +05:30
DN6
150d22821f update 2024-03-07 13:48:01 +05:30
DN6
6f3fd3bd51 update 2024-03-07 13:20:16 +05:30

View File

@@ -592,13 +592,15 @@ class StableCascadeUNet(ModelMixin, ConfigMixin):
# Model Blocks
x = self.embedding(sample)
# Interpolate operations are always run in fp32 in the original implementation
if hasattr(self, "effnet_mapper") and effnet is not None:
x = x + self.effnet_mapper(
nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode="bilinear", align_corners=True)
)
if hasattr(self, "pixels_mapper"):
x = x + nn.functional.interpolate(
self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True
self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode="bilinear", align_corners=True
)
level_outputs = self._down_encode(x, timestep_ratio_embed, clip)
x = self._up_decode(level_outputs, timestep_ratio_embed, clip)