Compare commits

...

18 Commits

Author     SHA1        Message                        Date
sayakpaul  a78bf5a070  add methods to pixart          2024-07-15 08:55:17 +05:30
sayakpaul  49b150d277  comment hack                   2024-07-15 08:35:49 +05:30
sayakpaul  b6959efd7d  pixart helps?                  2024-07-15 08:35:10 +05:30
sayakpaul  886b85e45d  comment                        2024-07-15 08:21:30 +05:30
sayakpaul  3a139f4329  okay                           2024-07-15 08:20:31 +05:30
sayakpaul  28c6d8824c  okay                           2024-07-15 08:18:23 +05:30
sayakpaul  f991caafcf  okay                           2024-07-15 08:12:39 +05:30
sayakpaul  08e6a1a7b4  checking shapes                2024-07-15 08:08:24 +05:30
sayakpaul  444a35d40b  reshape                        2024-07-15 08:07:01 +05:30
sayakpaul  097912e315  unpack.                        2024-07-15 08:06:05 +05:30
sayakpaul  0c677349f9  make contiguous                2024-07-15 08:04:54 +05:30
sayakpaul  351b6d0cbb  okay                           2024-07-15 08:03:39 +05:30
sayakpaul  536dbf4eac  vae                            2024-07-15 08:01:48 +05:30
sayakpaul  f7d06416be  none                           2024-07-15 07:59:31 +05:30
sayakpaul  a32de0a23b  correct import                 2024-07-15 07:56:40 +05:30
sayakpaul  25308b8490  benchmarking script.           2024-07-14 09:07:51 +05:30
sayakpaul  ecba31b4f7  references.                    2024-07-14 08:56:20 +05:30
sayakpaul  b766531e0f  add fa-3 attention processor.  2024-07-14 08:54:05 +05:30
3 changed files with 227 additions and 1 deletion

benchmark_fa3.py (new file, 72 lines)

@@ -0,0 +1,72 @@
import torch
from fa3_processor import FA3AttnProcessor
from diffusers import DiffusionPipeline
import argparse
import torch.utils.benchmark as benchmark
import gc
import json


def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def load_pipeline(args):
    pipeline = DiffusionPipeline.from_pretrained(
        "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
    ).to("cuda")
    if args.fa3:
        pipeline.transformer.set_attn_processor(FA3AttnProcessor())
        pipeline.vae.set_attn_processor(FA3AttnProcessor())
    pipeline.set_progress_bar_config(disable=True)
    return pipeline


def run_pipeline(pipeline, args):
    _ = pipeline(
        prompt="a cat with tiger-like looks",
        num_images_per_prompt=args.batch_size,
        guidance_scale=7.5,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fa3", default=0, type=int)
    parser.add_argument("--batch_size", default=1, type=int)
    args = parser.parse_args()

    flush()
    pipeline = load_pipeline(args)

    # Warm-up runs before timing.
    for _ in range(3):
        run_pipeline(pipeline, args)

    # Timed runs and peak-memory measurement.
    time = benchmark_fn(run_pipeline, pipeline, args)
    memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
    data_dict = dict(time=time, memory=memory)
    print(f"FA3: {bool(args.fa3)} Time: {time} seconds Memory: {memory} GB")

    filename_prefix = f"fa3@{args.fa3}-bs@{args.batch_size}"
    with open(f"{filename_prefix}.json", "w") as f:
        json.dump(data_dict, f)

    # Save a sample image for a quick visual sanity check.
    image = pipeline(
        prompt="a cat with tiger-like looks",
        num_images_per_prompt=args.batch_size,
        num_inference_steps=25,
        guidance_scale=7.5,
    ).images[0]
    image.save(f"{filename_prefix}.png")

fa3_processor.py (new file, 95 lines)

@@ -0,0 +1,95 @@
import torch
from flash_attn_interface import flash_attn_func


class FA3AttnProcessor:
    r"""
    Processor for using Flash Attention 3 (FA3) via `flash-attn`.

    To install `flash-attn` that supports FA3, follow:
    https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release

    Reference: https://tridao.me/blog/2024/flash3/
    """

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # FA3 expects (batch, seq_len, num_heads, head_dim) tensors.
        query = query.view(batch_size, -1, attn.heads, head_dim).contiguous()
        key = key.view(batch_size, -1, attn.heads, head_dim).contiguous()
        value = value.view(batch_size, -1, attn.heads, head_dim).contiguous()

        # nasty hack to make the head number and head dim compatible with FA3.
        # if attn.heads == 1 and head_dim == 512:
        #     factor = 8
        #     new_head_dim = head_dim // factor
        #     query = query.view(batch_size, -1, factor, new_head_dim)
        #     key = key.view(batch_size, -1, factor, new_head_dim)
        #     value = value.view(batch_size, -1, factor, new_head_dim)

        hidden_states, _ = flash_attn_func(
            query, key, value, softmax_scale=attn.scale, causal=False
        )
        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
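
For context, a minimal usage sketch of the processor above, mirroring load_pipeline() in benchmark_fa3.py; it assumes a Hopper GPU and the FA3-enabled flash-attn build linked in the docstring. Only the transformer is switched here; the benchmark script also switches the VAE, which is what the commented-out head-dim hack appears to target:

import torch
from diffusers import DiffusionPipeline
from fa3_processor import FA3AttnProcessor

pipeline = DiffusionPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Route all attention layers of the PixArt transformer through FA3
# (uses the set_attn_processor method added in the diff below).
pipeline.transformer.set_attn_processor(FA3AttnProcessor())

image = pipeline(prompt="a cat with tiger-like looks", guidance_scale=7.5).images[0]
image.save("fa3_sample.png")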

pixart_transformer_2d.py (modified)

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 from torch import nn
@@ -19,6 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
+from ..attention_processor import AttentionProcessor
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -186,6 +187,64 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
     def forward(
         self,
         hidden_states: torch.Tensor,