diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 37e61e2aaf..43615dc664 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -671,6 +671,8 @@ class StableDiffusionXLPipeline(
         return add_time_ids

     def upcast_vae(self):
+        from ...models.attention_processor import FusedAttnProcessor2_0
+
         dtype = self.vae.dtype
         self.vae.to(dtype=torch.float32)
         use_torch_2_0_or_xformers = isinstance(
@@ -680,6 +682,7 @@ class StableDiffusionXLPipeline(
                 XFormersAttnProcessor,
                 LoRAXFormersAttnProcessor,
                 LoRAAttnProcessor2_0,
+                FusedAttnProcessor2_0,
             ),
         )
         # if xformers or torch_2_0 is used attention block does not need
@@ -838,25 +841,25 @@ class StableDiffusionXLPipeline(
     def num_timesteps(self):
         return self._num_timesteps

-    def _enable_bfloat16_for_vae(self):
-        self.vae = self.vae.to(torch.bfloat16)
-        self.is_vae_in_blfoat16 = True
+    # def _enable_bfloat16_for_vae(self):
+    #     self.vae = self.vae.to(torch.bfloat16)
+    #     self.is_vae_in_blfoat16 = True

-    def _change_to_group_norm_32(self):
-        class GroupNorm32(torch.nn.GroupNorm):
-            def forward(self, x):
-                return super().forward(x.float()).type(x.dtype)
+    # def _change_to_group_norm_32(self):
+    #     class GroupNorm32(torch.nn.GroupNorm):
+    #         def forward(self, x):
+    #             return super().forward(x.float()).type(x.dtype)

-        def recursive_fn(model):
-            for name, module in model.named_children():
-                if isinstance(module, torch.nn.GroupNorm):
-                    new_gn = GroupNorm32(module.num_groups, module.num_channels, module.eps, module.affine)
-                    setattr(model, name, new_gn)
-                elif len(list(module.children())) > 0:
-                    recursive_fn(module)
+    #     def recursive_fn(model):
+    #         for name, module in model.named_children():
+    #             if isinstance(module, torch.nn.GroupNorm):
+    #                 new_gn = GroupNorm32(module.num_groups, module.num_channels, module.eps, module.affine)
+    #                 setattr(model, name, new_gn)
+    #             elif len(list(module.children())) > 0:
+    #                 recursive_fn(module)

-        recursive_fn(self.unet)
-        recursive_fn(self.vae)
+    #     recursive_fn(self.unet)
+    #     recursive_fn(self.vae)

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -1267,8 +1270,8 @@ class StableDiffusionXLPipeline(
                 self.upcast_vae()
                 latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

-            if hasattr(self, "is_vae_in_blfoat16") and self.is_vae_in_blfoat16:
-                latents = latents.to(self.vae.dtype)
+            # if hasattr(self, "is_vae_in_blfoat16") and self.is_vae_in_blfoat16:
+            #     latents = latents.to(self.vae.dtype)
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

             # cast back to fp16 if needed
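
For context, a minimal usage sketch of the code path this diff touches. It assumes the fuse_qkv_projections() pipeline API (which installs FusedAttnProcessor2_0 on the UNet and VAE attention blocks) and the stabilityai/stable-diffusion-xl-base-1.0 checkpoint; neither is part of this diff:

import torch
from diffusers import StableDiffusionXLPipeline

# Load the pipeline in fp16; upcast_vae() exists to work around fp16 VAE overflow
# by temporarily moving the VAE to float32 before decoding.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# fuse_qkv_projections() swaps the attention processors for FusedAttnProcessor2_0.
# Before this diff, upcast_vae() did not recognize the fused processor, so it left
# post_quant_conv and the decoder's conv_in/mid_block in float32; with the extended
# isinstance check they are cast back to fp16, saving memory during decoding.
pipe.fuse_qkv_projections()
pipe.upcast_vae()  # normally called internally when the fp16 VAE needs upcasting

image = pipe("an astronaut riding a green horse").images[0]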