Compare commits

...

3 Commits

Author SHA1 Message Date
DN6
1ca9acc269 update 2024-03-07 14:08:39 +05:30
DN6
150d22821f update 2024-03-07 13:48:01 +05:30
DN6
6f3fd3bd51 update 2024-03-07 13:20:16 +05:30

View File

@@ -592,13 +592,15 @@ class StableCascadeUNet(ModelMixin, ConfigMixin):
# Model Blocks # Model Blocks
x = self.embedding(sample) x = self.embedding(sample)
# Interpolate operations are always run in fp32 in the original implementation
if hasattr(self, "effnet_mapper") and effnet is not None: if hasattr(self, "effnet_mapper") and effnet is not None:
x = x + self.effnet_mapper( x = x + self.effnet_mapper(
nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True) nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode="bilinear", align_corners=True)
) )
if hasattr(self, "pixels_mapper"): if hasattr(self, "pixels_mapper"):
x = x + nn.functional.interpolate( x = x + nn.functional.interpolate(
self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode="bilinear", align_corners=True
) )
level_outputs = self._down_encode(x, timestep_ratio_embed, clip) level_outputs = self._down_encode(x, timestep_ratio_embed, clip)
x = self._up_decode(level_outputs, timestep_ratio_embed, clip) x = self._up_decode(level_outputs, timestep_ratio_embed, clip)