mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-12 23:44:30 +08:00
comment out stuff that didn't work
This commit is contained in:
@@ -671,6 +671,8 @@ class StableDiffusionXLPipeline(
|
||||
return add_time_ids
|
||||
|
||||
def upcast_vae(self):
|
||||
from ...models.attention_processor import FusedAttnProcessor2_0
|
||||
|
||||
dtype = self.vae.dtype
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
@@ -680,6 +682,7 @@ class StableDiffusionXLPipeline(
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
@@ -838,25 +841,25 @@ class StableDiffusionXLPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
def _enable_bfloat16_for_vae(self):
|
||||
self.vae = self.vae.to(torch.bfloat16)
|
||||
self.is_vae_in_blfoat16 = True
|
||||
# def _enable_bfloat16_for_vae(self):
|
||||
# self.vae = self.vae.to(torch.bfloat16)
|
||||
# self.is_vae_in_blfoat16 = True
|
||||
|
||||
def _change_to_group_norm_32(self):
|
||||
class GroupNorm32(torch.nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
# def _change_to_group_norm_32(self):
|
||||
# class GroupNorm32(torch.nn.GroupNorm):
|
||||
# def forward(self, x):
|
||||
# return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
def recursive_fn(model):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.GroupNorm):
|
||||
new_gn = GroupNorm32(module.num_groups, module.num_channels, module.eps, module.affine)
|
||||
setattr(model, name, new_gn)
|
||||
elif len(list(module.children())) > 0:
|
||||
recursive_fn(module)
|
||||
# def recursive_fn(model):
|
||||
# for name, module in model.named_children():
|
||||
# if isinstance(module, torch.nn.GroupNorm):
|
||||
# new_gn = GroupNorm32(module.num_groups, module.num_channels, module.eps, module.affine)
|
||||
# setattr(model, name, new_gn)
|
||||
# elif len(list(module.children())) > 0:
|
||||
# recursive_fn(module)
|
||||
|
||||
recursive_fn(self.unet)
|
||||
recursive_fn(self.vae)
|
||||
# recursive_fn(self.unet)
|
||||
# recursive_fn(self.vae)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
@@ -1267,8 +1270,8 @@ class StableDiffusionXLPipeline(
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if hasattr(self, "is_vae_in_blfoat16") and self.is_vae_in_blfoat16:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
# if hasattr(self, "is_vae_in_blfoat16") and self.is_vae_in_blfoat16:
|
||||
# latents = latents.to(self.vae.dtype)
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
|
||||
Reference in New Issue
Block a user