Compare commits


1 Commit

Author        SHA1         Message                                  Date
patil-suraj   ae0cc0b71f   allow passing op to xformers attention   2022-12-07 20:10:56 +01:00
2 changed files with 20 additions and 9 deletions
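
Taken together, the change threads an optional attention_op argument from the public pipeline API down to the raw xformers call. A minimal usage sketch of the new parameter (the checkpoint name and the choice of MemoryEfficientAttentionFlashAttentionOp, shipped with recent xformers builds, are illustrative and not part of this commit):

    import torch
    from diffusers import StableDiffusionPipeline
    from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # New in this commit: pin a specific xformers attention op.
    # Leaving attention_op unset (None) keeps the old auto-selection behavior.
    pipe.enable_xformers_memory_efficient_attention(
        attention_op=MemoryEfficientAttentionFlashAttentionOp
    )

    image = pipe("an astronaut riding a horse").images[0]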

View File

@@ -14,7 +14,7 @@
 import math
 import warnings
 from dataclasses import dataclass
-from typing import Optional
+from typing import Callable, Optional

 import torch
 import torch.nn.functional as F
@@ -286,7 +286,9 @@ class AttentionBlock(nn.Module):
         self._use_memory_efficient_attention_xformers = False

-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if not is_xformers_available():
             raise ModuleNotFoundError(
                 "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
@@ -457,7 +459,9 @@ class BasicTransformerBlock(nn.Module):
                     f" correctly and a GPU is available: {e}"
                 )

-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if not is_xformers_available():
             print("Here is how to install it")
             raise ModuleNotFoundError(
@@ -481,7 +485,9 @@ class BasicTransformerBlock(nn.Module):
             except Exception as e:
                 raise e
         self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        self.attn1._attention_op = attention_op
         self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        self.attn2._attention_op = attention_op

     def forward(self, hidden_states, context=None, timestep=None):
         # 1. Self-Attention
@@ -545,6 +551,7 @@ class CrossAttention(nn.Module):
         self.sliceable_head_dim = heads
         self._slice_size = None
         self._use_memory_efficient_attention_xformers = False
+        self._attention_op = None

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
         self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
@@ -671,7 +678,9 @@ class CrossAttention(nn.Module):
         query = query.contiguous()
         key = key.contiguous()
         value = value.contiguous()
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=None, op=self._attention_op
+        )
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
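
For context, op is xformers' own dispatch hook: with op=None (the default preserved by the new _attention_op attribute) xformers selects a kernel for the given shapes and dtypes, while an explicit op pins one implementation. A standalone sketch (the tensor shapes and the FlashAttention op are illustrative; a pinned op must actually support the inputs, otherwise xformers raises an error):

    import torch
    import xformers.ops

    # (batch * heads, seq_len, head_dim), fp16 on GPU, as produced by CrossAttention
    q = torch.randn(16, 4096, 40, device="cuda", dtype=torch.float16)
    k = torch.randn(16, 4096, 40, device="cuda", dtype=torch.float16)
    v = torch.randn(16, 4096, 40, device="cuda", dtype=torch.float16)

    # op=None: auto-select a compatible kernel (the behavior before this commit)
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)

    # explicit op: force the FlashAttention-backed implementation
    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=None, op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp
    )

This is why CrossAttention only needs to store the callable in self._attention_op and pass it through verbatim.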

View File

@@ -19,7 +19,7 @@ import inspect
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
 import torch
@@ -812,7 +812,7 @@ class DiffusionPipeline(ConfigMixin):
     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs

-    def enable_xformers_memory_efficient_attention(self):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
         Enable memory efficient attention as implemented in xformers.
@@ -822,7 +822,7 @@ class DiffusionPipeline(ConfigMixin):
         Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
         is used.
         """
-        self.set_use_memory_efficient_attention_xformers(True)
+        self.set_use_memory_efficient_attention_xformers(True, attention_op)

     def disable_xformers_memory_efficient_attention(self):
         r"""
@@ -830,13 +830,15 @@
         """
         self.set_use_memory_efficient_attention_xformers(False)

-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention_xformers(
+        self, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_memory_efficient_attention_xformers method
         # gets the message
         def fn_recursive_set_mem_eff(module: torch.nn.Module):
             if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-                module.set_use_memory_efficient_attention_xformers(valid)
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

             for child in module.children():
                 fn_recursive_set_mem_eff(child)
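
The recursive walk means every module in the pipeline's tree that exposes set_use_memory_efficient_attention_xformers receives both the flag and the op, however deeply it is nested. A self-contained toy version of the same pattern (the FakeAttention class is invented for illustration):

    import torch

    class FakeAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self._use_memory_efficient_attention_xformers = False
            self._attention_op = None

        def set_use_memory_efficient_attention_xformers(self, valid, attention_op=None):
            self._use_memory_efficient_attention_xformers = valid
            self._attention_op = attention_op

    # attention blocks can sit at any depth in the module tree
    root = torch.nn.Sequential(FakeAttention(), torch.nn.Sequential(FakeAttention()))

    def fn_recursive_set_mem_eff(module: torch.nn.Module):
        if hasattr(module, "set_use_memory_efficient_attention_xformers"):
            module.set_use_memory_efficient_attention_xformers(True, attention_op=None)
        for child in module.children():
            fn_recursive_set_mem_eff(child)

    fn_recursive_set_mem_eff(root)
    assert all(m._use_memory_efficient_attention_xformers
               for m in root.modules() if isinstance(m, FakeAttention))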