Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-24 20:46:08 +08:00

Compare commits: modular-lo...xformers-a (1 commit, SHA1 ae0cc0b71f)
The commit threads a new optional `attention_op` argument through the xformers memory-efficient attention path, so callers can pin a specific xformers operator instead of relying on automatic kernel selection. In the attention module (`src/diffusers/models/attention.py`), `Callable` is imported for the new signatures:

```diff
@@ -14,7 +14,7 @@
 import math
 import warnings
 from dataclasses import dataclass
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
 import torch.nn.functional as F
```
```diff
@@ -286,7 +286,9 @@ class AttentionBlock(nn.Module):
 
         self._use_memory_efficient_attention_xformers = False
 
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if not is_xformers_available():
             raise ModuleNotFoundError(
                 "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
```
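For context (an aside, not part of the commit): a valid `attention_op` is one of the operators exported by `xformers.ops`. A minimal sketch of calling the new setter directly, assuming an xformers build that ships the flash-attention op; `block` is a hypothetical `AttentionBlock` instance:

```python
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

# `block` is a hypothetical AttentionBlock instance; passing an op pins the
# xformers kernel, while the default None keeps automatic selection.
block.set_use_memory_efficient_attention_xformers(
    True, attention_op=MemoryEfficientAttentionFlashAttentionOp
)
```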
`BasicTransformerBlock` gets the same signature change:

```diff
@@ -457,7 +459,9 @@ class BasicTransformerBlock(nn.Module):
                     f" correctly and a GPU is available: {e}"
                 )
 
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if not is_xformers_available():
             print("Here is how to install it")
             raise ModuleNotFoundError(
```
```diff
@@ -481,7 +485,9 @@ class BasicTransformerBlock(nn.Module):
             except Exception as e:
                 raise e
         self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        self.attn1._attention_op = attention_op
         self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        self.attn2._attention_op = attention_op
 
     def forward(self, hidden_states, context=None, timestep=None):
         # 1. Self-Attention
```
```diff
@@ -545,6 +551,7 @@ class CrossAttention(nn.Module):
         self.sliceable_head_dim = heads
         self._slice_size = None
         self._use_memory_efficient_attention_xformers = False
+        self._attention_op = None
 
         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
         self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
```
```diff
@@ -671,7 +678,9 @@
         query = query.contiguous()
         key = key.contiguous()
         value = value.contiguous()
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=None, op=self._attention_op
+        )
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
```
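The stored op is forwarded as the `op` keyword of `xformers.ops.memory_efficient_attention`. A standalone sketch of the difference this makes, under the assumptions of a CUDA device, fp16 tensors, and a flash-compatible head size (shapes here are illustrative):

```python
import torch
import xformers.ops
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

# (batch*heads, seq_len, head_dim), matching the layout CrossAttention uses.
q = torch.randn(16, 4096, 40, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# op=None: xformers auto-selects the best kernel it can for these inputs.
out_auto = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)

# Explicit op: dispatch is pinned; xformers raises if the chosen op
# cannot handle the inputs (dtype, device, head size).
out_flash = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=None, op=MemoryEfficientAttentionFlashAttentionOp
)
```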
The same plumbing lands in the `DiffusionPipeline` base class (`pipeline_utils.py`), starting with the import:

```diff
@@ -19,7 +19,7 @@ import inspect
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
```
```diff
@@ -812,7 +812,7 @@ class DiffusionPipeline(ConfigMixin):
     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs
 
-    def enable_xformers_memory_efficient_attention(self):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
         Enable memory efficient attention as implemented in xformers.
 
```
```diff
@@ -822,7 +822,7 @@ class DiffusionPipeline(ConfigMixin):
         Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
         is used.
         """
-        self.set_use_memory_efficient_attention_xformers(True)
+        self.set_use_memory_efficient_attention_xformers(True, attention_op)
 
     def disable_xformers_memory_efficient_attention(self):
         r"""
```
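At the user-facing end, this is what the new argument looks like in practice. A usage sketch, assuming a CUDA machine with xformers installed (the model id and op choice are illustrative):

```python
import torch
from diffusers import StableDiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# No argument keeps the old behavior (automatic op selection);
# passing an op forces that kernel throughout the pipeline.
pipe.enable_xformers_memory_efficient_attention(
    attention_op=MemoryEfficientAttentionFlashAttentionOp
)
image = pipe("an astronaut riding a horse on mars").images[0]
```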
```diff
@@ -830,13 +830,15 @@ class DiffusionPipeline(ConfigMixin):
         """
         self.set_use_memory_efficient_attention_xformers(False)
 
-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention_xformers(
+        self, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_memory_efficient_attention_xformers method
         # gets the message
         def fn_recursive_set_mem_eff(module: torch.nn.Module):
             if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-                module.set_use_memory_efficient_attention_xformers(valid)
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
 
             for child in module.children():
                 fn_recursive_set_mem_eff(child)
```
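The walker above is a compact, reusable pattern: broadcast a setting to every submodule that opts in by exposing the method. A self-contained sketch of the same idea with hypothetical names (`broadcast_setting` and `unet` are not part of diffusers):

```python
import torch

def broadcast_setting(module: torch.nn.Module, method: str, *args) -> None:
    """Invoke `method(*args)` on this module and every descendant that defines it."""
    if hasattr(module, method):
        getattr(module, method)(*args)
    for child in module.children():
        broadcast_setting(child, method, *args)

# Mirroring the diff (unet is a hypothetical torch.nn.Module):
# broadcast_setting(unet, "set_use_memory_efficient_attention_xformers", True, None)
```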