Compare commits

..

3 Commits

Author SHA1 Message Date
Daniel Gu
ae666442ac make style and make quality 2026-02-17 02:58:28 +01:00
Daniel Gu
010104bb42 Remove xfail for PRX pipeline tests as they appear to work on transformers>4.57.1 2026-02-17 02:57:41 +01:00
Daniel Gu
c119560810 Guard ftfy import with is_ftfy_available 2026-02-17 02:56:45 +01:00
4 changed files with 8 additions and 44 deletions

View File

@@ -1117,26 +1117,6 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None):
# Skip Attention Mask if all values are 1, `None` mask can speedup the computation
if attn_mask is not None and torch.all(attn_mask != 0):
attn_mask = None
# Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k]
# https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
if (
attn_mask is not None
and attn_mask.ndim == 2
and attn_mask.shape[0] == query.shape[0]
and attn_mask.shape[1] == key.shape[1]
):
B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
attn_mask = ~attn_mask.to(torch.bool)
attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
return attn_mask
def _npu_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1154,14 +1134,11 @@ def _npu_attention_forward_op(
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
atten_mask=attn_mask,
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2691,17 +2668,16 @@ def _native_npu_attention(
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for NPU attention")
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
if _parallel_config is None:
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
atten_mask=attn_mask,
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2716,7 +2692,7 @@ def _native_npu_attention(
query,
key,
value,
attn_mask,
None,
dropout_p,
None,
scale,

View File

@@ -164,11 +164,7 @@ def compute_text_seq_len_from_mask(
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)
per_sample_len = torch.where(
has_active,
active_positions.max(dim=1).values + 1,
torch.as_tensor(text_seq_len, device=encoder_hidden_states.device),
)
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
return text_seq_len, per_sample_len, encoder_hidden_states_mask

View File

@@ -18,7 +18,6 @@ import re
import urllib.parse as ul
from typing import Callable
import ftfy
import torch
from transformers import (
AutoTokenizer,
@@ -34,13 +33,13 @@ from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
logging,
replace_example_docstring,
)
from diffusers.utils import is_ftfy_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
if is_ftfy_available():
import ftfy
DEFAULT_RESOLUTION = 512
ASPECT_RATIO_256_BIN = {

View File

@@ -1,7 +1,6 @@
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
@@ -11,17 +10,11 @@ from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx.pipeline_prx import PRXPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_transformers_version
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@pytest.mark.xfail(
condition=is_transformers_version(">", "4.57.1"),
reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
strict=False,
)
class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = PRXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}