Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 04:54:47 +08:00)
Compare commits: sage-kerne ... fa3-tests (18 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | a78bf5a070 |  |
|  | 49b150d277 |  |
|  | b6959efd7d |  |
|  | 886b85e45d |  |
|  | 3a139f4329 |  |
|  | 28c6d8824c |  |
|  | f991caafcf |  |
|  | 08e6a1a7b4 |  |
|  | 444a35d40b |  |
|  | 097912e315 |  |
|  | 0c677349f9 |  |
|  | 351b6d0cbb |  |
|  | 536dbf4eac |  |
|  | f7d06416be |  |
|  | a32de0a23b |  |
|  | 25308b8490 |  |
|  | ecba31b4f7 |  |
|  | b766531e0f |  |
benchmark_fa3.py (72 lines, new normal file)
@@ -0,0 +1,72 @@
import torch
from fa3_processor import FA3AttnProcessor
from diffusers import DiffusionPipeline
import argparse
import torch.utils.benchmark as benchmark
import gc
import json


def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def load_pipeline(args):
    pipeline = DiffusionPipeline.from_pretrained(
        "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
    ).to("cuda")
    if args.fa3:
        pipeline.transformer.set_attn_processor(FA3AttnProcessor())
        pipeline.vae.set_attn_processor(FA3AttnProcessor())

    pipeline.set_progress_bar_config(disable=True)
    return pipeline


def run_pipeline(pipeline, args):
    _ = pipeline(
        prompt="a cat with tiger-like looks",
        num_images_per_prompt=args.batch_size,
        guidance_scale=7.5
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fa3", default=0, type=int)
    parser.add_argument("--batch_size", default=1, type=int)
    args = parser.parse_args()

    flush()

    pipeline = load_pipeline(args)

    for _ in range(3):
        run_pipeline(pipeline, args)

    time = benchmark_fn(run_pipeline, pipeline, args)
    memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
    data_dict = dict(time=time, memory=memory)
    print(f"FA3: {bool(args.fa3)} Time: {time} seconds Memory: {memory} GB")

    filename_prefix = f"fa3@{args.fa3}-bs@{args.batch_size}"
    with open(f"{filename_prefix}.json", "w") as f:
        json.dump(data_dict, f)

    image = pipeline(
        prompt="a cat with tiger-like looks",
        num_images_per_prompt=args.batch_size,
        num_inference_steps=25,
        guidance_scale=7.5
    ).images[0]
    image.save(f"{filename_prefix}.png")
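The script above writes its timing and peak-memory numbers to `{filename_prefix}.json` (for example `fa3@0-bs@1.json` and `fa3@1-bs@1.json` with the default batch size). A minimal sketch of comparing the two runs afterwards, assuming the script has been run once with `--fa3 0` and once with `--fa3 1`; the comparison snippet itself is illustrative and not part of the branch:

```python
import json

# Illustrative comparison of the two benchmark runs (batch size 1 assumed).
# benchmark_fa3.py stores "time" and "memory" as strings, so cast them back
# to float before computing the ratio.
with open("fa3@0-bs@1.json") as f:
    baseline = json.load(f)
with open("fa3@1-bs@1.json") as f:
    fa3 = json.load(f)

speedup = float(baseline["time"]) / float(fa3["time"])
print(f"Baseline: {baseline['time']} s, {baseline['memory']} GB")
print(f"FA3:      {fa3['time']} s, {fa3['memory']} GB")
print(f"Speedup from FA3: {speedup:.2f}x")
```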
fa3_processor.py (95 lines, new normal file)
@@ -0,0 +1,95 @@
import torch
from flash_attn_interface import flash_attn_func


class FA3AttnProcessor:
    r"""
    Processor for using Flash Attention 3 (FA3) via `flash-attn`.

    To install `flash-attn` that supports FA3, follow:
    https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release

    Reference: https://tridao.me/blog/2024/flash3/
    """

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).contiguous()
        key = key.view(batch_size, -1, attn.heads, head_dim).contiguous()
        value = value.view(batch_size, -1, attn.heads, head_dim).contiguous()

        # nasty hack to make the head number and head dim compatible with FA3.
        # if attn.heads == 1 and head_dim == 512:
        #     factor = 8
        #     new_head_dim = head_dim // factor
        #     query = query.view(batch_size, -1, factor, new_head_dim)
        #     key = key.view(batch_size, -1, factor, new_head_dim)
        #     value = value.view(batch_size, -1, factor, new_head_dim)
        hidden_states, _ = flash_attn_func(
            query, key, value, softmax_scale=attn.scale, causal=False
        )
        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
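For a quick numerical sanity check of the kernel the processor relies on, one could compare `flash_attn_func` against PyTorch's SDPA on random tensors. This is a minimal sketch, assuming an FA3-capable GPU (e.g. H100) and the beta `flash_attn_interface` package used above; the shapes follow the `(batch, seq_len, heads, head_dim)` layout the processor builds, and the tuple unpacking mirrors how the processor calls the function. The concrete sizes are illustrative:

```python
import torch
import torch.nn.functional as F
from flash_attn_interface import flash_attn_func

# Illustrative shapes in the (batch, seq_len, heads, head_dim) layout used above.
batch, seq_len, heads, head_dim = 2, 1024, 8, 64
q = torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# The FA3 beta interface returns (output, log-sum-exp), as unpacked in the processor.
out_fa3, _ = flash_attn_func(q, k, v, causal=False)

# SDPA expects (batch, heads, seq_len, head_dim), so transpose in and out.
out_ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

# Expect only a small fp16-level difference between the two outputs.
print((out_fa3 - out_ref).abs().max())
```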
pixart_transformer_2d.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 import torch
 from torch import nn
@@ -19,6 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
+from ..attention_processor import AttentionProcessor
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -186,6 +187,64 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
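The dict form of the newly added `set_attn_processor` must cover every attention layer, with keys matching those returned by `attn_processors` (each key ends in `.processor`). A minimal sketch of mixing processors per layer, assuming the `FA3AttnProcessor` from this branch and diffusers' stock `AttnProcessor2_0`; the checkpoint and subfolder mirror the benchmark script above, while the attn1/attn2 split is only illustrative:

```python
import torch
from diffusers import PixArtTransformer2DModel
from diffusers.models.attention_processor import AttnProcessor2_0
from fa3_processor import FA3AttnProcessor

transformer = PixArtTransformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", subfolder="transformer", torch_dtype=torch.float16
)

# Build a dict keyed exactly like `attn_processors` ("<module path>.processor").
# Here self-attention ("attn1") gets FA3 while cross-attention keeps SDPA.
processors = {
    name: FA3AttnProcessor() if "attn1" in name else AttnProcessor2_0()
    for name in transformer.attn_processors.keys()
}
transformer.set_attn_processor(processors)
```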