comment out stuff that didn't work

sayakpaul committed 2023-12-04 12:50:14 +05:30
parent d944d8b108
commit 0432297da6


@@ -671,6 +671,8 @@ class StableDiffusionXLPipeline(
         return add_time_ids

     def upcast_vae(self):
+        from ...models.attention_processor import FusedAttnProcessor2_0
+
         dtype = self.vae.dtype
         self.vae.to(dtype=torch.float32)
         use_torch_2_0_or_xformers = isinstance(
@@ -680,6 +682,7 @@ class StableDiffusionXLPipeline(
                 XFormersAttnProcessor,
                 LoRAXFormersAttnProcessor,
                 LoRAAttnProcessor2_0,
+                FusedAttnProcessor2_0,
             ),
         )
         # if xformers or torch_2_0 is used attention block does not need
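
FusedAttnProcessor2_0 is added to this check so that, once the QKV projections have been fused, upcast_vae() still recognizes the processor as a memory-efficient one and keeps the VAE's attention blocks out of float32. A rough standalone sketch of that pattern (the helper name is made up; the attribute paths mirror the pipeline code above):

    # Illustrative sketch of what upcast_vae() does, written as a free function.
    # Assumes a diffusers AutoencoderKL-style `vae`.
    import torch

    def upcast_vae_keeping_attention(vae):
        dtype = vae.dtype                    # e.g. torch.float16
        vae.to(dtype=torch.float32)          # upcast to avoid overflow during decode
        processor = vae.decoder.mid_block.attentions[0].processor
        # With SDPA/xFormers-style processors (including the fused-QKV variant
        # added above), attention is numerically safe in the original dtype,
        # so those blocks can be cast back to save memory.
        if type(processor).__name__ in {
            "AttnProcessor2_0",
            "XFormersAttnProcessor",
            "LoRAXFormersAttnProcessor",
            "LoRAAttnProcessor2_0",
            "FusedAttnProcessor2_0",
        }:
            vae.post_quant_conv.to(dtype)
            vae.decoder.conv_in.to(dtype)
            vae.decoder.mid_block.to(dtype)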
@@ -838,25 +841,25 @@ class StableDiffusionXLPipeline(
     def num_timesteps(self):
         return self._num_timesteps

-    def _enable_bfloat16_for_vae(self):
-        self.vae = self.vae.to(torch.bfloat16)
-        self.is_vae_in_blfoat16 = True
+    # def _enable_bfloat16_for_vae(self):
+    #     self.vae = self.vae.to(torch.bfloat16)
+    #     self.is_vae_in_blfoat16 = True

-    def _change_to_group_norm_32(self):
-        class GroupNorm32(torch.nn.GroupNorm):
-            def forward(self, x):
-                return super().forward(x.float()).type(x.dtype)
+    # def _change_to_group_norm_32(self):
+    #     class GroupNorm32(torch.nn.GroupNorm):
+    #         def forward(self, x):
+    #             return super().forward(x.float()).type(x.dtype)

-        def recursive_fn(model):
-            for name, module in model.named_children():
-                if isinstance(module, torch.nn.GroupNorm):
-                    new_gn = GroupNorm32(module.num_groups, module.num_channels, module.eps, module.affine)
-                    setattr(model, name, new_gn)
-                elif len(list(module.children())) > 0:
-                    recursive_fn(module)
+    #     def recursive_fn(model):
+    #         for name, module in model.named_children():
+    #             if isinstance(module, torch.nn.GroupNorm):
+    #                 new_gn = GroupNorm32(module.num_groups, module.num_channels, module.eps, module.affine)
+    #                 setattr(model, name, new_gn)
+    #             elif len(list(module.children())) > 0:
+    #                 recursive_fn(module)

-        recursive_fn(self.unet)
-        recursive_fn(self.vae)
+    #     recursive_fn(self.unet)
+    #     recursive_fn(self.vae)

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
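
The two helpers commented out above were alternatives to the full float32 upcast of the VAE: casting the whole VAE to bfloat16, and running only the GroupNorm layers in float32. For reference, a hedged sketch of the GroupNorm idea as a standalone utility; unlike the draft above, it also copies the learned affine parameters into the replacement module (the function name is illustrative):

    import torch

    class GroupNorm32(torch.nn.GroupNorm):
        def forward(self, x):
            # run the normalization in float32, then cast back to the input dtype
            return super().forward(x.float()).type(x.dtype)

    def swap_group_norms_to_fp32(model: torch.nn.Module) -> None:
        for name, module in model.named_children():
            if isinstance(module, torch.nn.GroupNorm):
                new_gn = GroupNorm32(module.num_groups, module.num_channels, module.eps, module.affine)
                if module.affine:
                    new_gn.load_state_dict(module.state_dict())  # keep trained weight/bias
                    new_gn = new_gn.to(module.weight.device)
                setattr(model, name, new_gn)
            elif len(list(module.children())) > 0:
                swap_group_norms_to_fp32(module)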
@@ -1267,8 +1270,8 @@ class StableDiffusionXLPipeline(
                 self.upcast_vae()
                 latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

-            if hasattr(self, "is_vae_in_blfoat16") and self.is_vae_in_blfoat16:
-                latents = latents.to(self.vae.dtype)
+            # if hasattr(self, "is_vae_in_blfoat16") and self.is_vae_in_blfoat16:
+            #     latents = latents.to(self.vae.dtype)

             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             # cast back to fp16 if needed
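
For comparison, the released pipeline avoids fp16 overflow by upcasting the VAE before decoding and casting it back afterwards, rather than keeping it in bfloat16. A usage sketch with the public API (model id and prompt are placeholders):

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    # Ask for latents so the decode precision can be controlled manually.
    latents = pipe(prompt="a photo of a cat", output_type="latent").images

    pipe.upcast_vae()  # VAE to float32; attention blocks stay in fp16 (see above)
    latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]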