Compare commits

...

1 Commit

Author:  Dhruv Nair
SHA1:    4ea7210ca9
Message: update
Date:    2024-08-12 12:31:51 +02:00


@@ -311,6 +311,27 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
+    def enable_mixed_precision_inference(self, upcast_dtype=None, downcast_dtype=None):
+        downcast_dtype = downcast_dtype or torch.float16
+        upcast_dtype = upcast_dtype or torch.float32
+
+        def pre_hook_fn(module, input):
+            module = module.to(upcast_dtype)
+
+        def hook_fn(module, input, output):
+            module = module.to(downcast_dtype)
+
+        def fn_recursive_upcast(module: torch.nn.Module):
+            if not list(module.children()):
+                module.register_forward_pre_hook(pre_hook_fn)
+                module.register_forward_hook(hook_fn)
+
+            for child in module.children():
+                fn_recursive_upcast(child)
+
+        for module in self.children():
+            fn_recursive_upcast(module)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
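
The new enable_mixed_precision_inference method walks the module tree and registers a pair of hooks on every leaf module: a forward pre-hook that casts the leaf's parameters to upcast_dtype (float32 by default) just before its forward pass, and a forward hook that casts them back to downcast_dtype (float16 by default) immediately afterwards, so only the currently executing leaf holds full-precision weights. A minimal usage sketch, assuming the transformer is loaded from the FLUX.1-dev checkpoint (the repo id and explicit dtypes here are illustrative assumptions, not part of the commit):

    import torch
    from diffusers import FluxTransformer2DModel

    # Load the transformer weights in half precision.
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.float16,
    )

    # Register the per-leaf casting hooks added by this commit: each leaf
    # module is upcast to float32 just before its forward pass and cast
    # back to float16 right after it returns.
    transformer.enable_mixed_precision_inference(
        upcast_dtype=torch.float32,
        downcast_dtype=torch.float16,
    )

Two notes on the hook semantics: nn.Module.to() mutates the module in place and returns self, so the "module = module.to(...)" reassignments inside the hooks are redundant but harmless; and because pre_hook_fn returns None, PyTorch leaves the forward inputs unchanged, meaning only the parameters change dtype while activations keep the dtype they arrived with.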