mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-14 07:35:41 +08:00
Compare commits
3 Commits
model-test
...
sayakpaul-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
537d4de2cc | ||
|
|
9d68742214 | ||
|
|
f1a93c765f |
@@ -314,6 +314,35 @@ Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
|
||||
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
|
||||
```
|
||||
|
||||
### Unified Attention
|
||||
|
||||
[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.
|
||||
|
||||
This hybrid approach leverages the strengths of both methods:
|
||||
- **Ulysses Attention** efficiently parallelizes across attention heads
|
||||
- **Ring Attention** handles very long sequences with minimal memory overhead
|
||||
- Together, they enable 2D parallelization across both heads and sequence dimensions
|
||||
|
||||
[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping).
|
||||
Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to bigger than 1 to [`~ModelMixin.enable_parallelism`].
|
||||
|
||||
```py
|
||||
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices).
|
||||
|
||||
We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows:
|
||||
|
||||
| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |
|
||||
|--------------------|------------------|-------------|------------------|
|
||||
| ulysses | 6670.789 | 7.50 | 33.85 |
|
||||
| ring | 13076.492 | 3.82 | 56.02 |
|
||||
| unified_balanced | 11068.705 | 4.52 | 33.85 |
|
||||
|
||||
From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.
|
||||
|
||||
### parallel_config
|
||||
|
||||
Pass `parallel_config` during model initialization to enable context parallelism.
|
||||
|
||||
@@ -214,7 +214,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
@@ -641,7 +641,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
@@ -1081,7 +1081,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -1377,7 +1377,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -1659,7 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
if not (has_lora_keys or has_norm_keys):
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
transformer_lora_state_dict = {
|
||||
k: state_dict.get(k)
|
||||
@@ -2506,7 +2506,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -2703,7 +2703,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -2906,7 +2906,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3115,7 +3115,7 @@ class LTX2LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
transformer_peft_state_dict = {
|
||||
k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
|
||||
@@ -3333,7 +3333,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3536,7 +3536,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3740,7 +3740,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3940,7 +3940,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -4194,7 +4194,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
|
||||
if load_into_transformer_2:
|
||||
@@ -4471,7 +4471,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
|
||||
if load_into_transformer_2:
|
||||
@@ -4691,7 +4691,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -4894,7 +4894,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -5100,7 +5100,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -5306,7 +5306,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -5509,7 +5509,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
|
||||
@@ -90,10 +90,6 @@ class ContextParallelConfig:
|
||||
)
|
||||
if self.ring_degree < 1 or self.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if self.ring_degree > 1 and self.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
|
||||
@@ -1177,6 +1177,103 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
|
||||
"""
|
||||
Perform dimension sharding / reassembly across processes using _all_to_all_single.
|
||||
|
||||
This utility reshapes and redistributes tensor `x` across the given process group, across sequence dimension or
|
||||
head dimension flexibly by accepting scatter_idx and gather_idx.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor):
|
||||
Input tensor. Expected shapes:
|
||||
- When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
|
||||
- When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
|
||||
scatter_idx (int) :
|
||||
Dimension along which the tensor is partitioned before all-to-all.
|
||||
gather_idx (int):
|
||||
Dimension along which the output is reassembled after all-to-all.
|
||||
group :
|
||||
Distributed process group for the Ulysses group.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Tensor with globally exchanged dimensions.
|
||||
- For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
|
||||
- For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
|
||||
"""
|
||||
group_world_size = torch.distributed.get_world_size(group)
|
||||
|
||||
if scatter_idx == 2 and gather_idx == 1:
|
||||
# Used before Ulysses sequence parallel (SP) attention. Scatters the gathers sequence
|
||||
# dimension and scatters head dimension
|
||||
batch_size, seq_len_local, num_heads, head_dim = x.shape
|
||||
seq_len = seq_len_local * group_world_size
|
||||
num_heads_local = num_heads // group_world_size
|
||||
|
||||
# B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
|
||||
x_temp = (
|
||||
x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
|
||||
.transpose(0, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
if group_world_size > 1:
|
||||
out = _all_to_all_single(x_temp, group=group)
|
||||
else:
|
||||
out = x_temp
|
||||
# group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
|
||||
out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
|
||||
out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
|
||||
return out
|
||||
elif scatter_idx == 1 and gather_idx == 2:
|
||||
# Used after ulysses sequence parallel in unified SP. gathers the head dimension
|
||||
# scatters back the sequence dimension.
|
||||
batch_size, seq_len, num_heads_local, head_dim = x.shape
|
||||
num_heads = num_heads_local * group_world_size
|
||||
seq_len_local = seq_len // group_world_size
|
||||
|
||||
# B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
|
||||
x_temp = (
|
||||
x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
|
||||
.permute(1, 3, 2, 0, 4)
|
||||
.reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
|
||||
)
|
||||
|
||||
if group_world_size > 1:
|
||||
output = _all_to_all_single(x_temp, group)
|
||||
else:
|
||||
output = x_temp
|
||||
output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
|
||||
output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
|
||||
return output
|
||||
else:
|
||||
raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
|
||||
|
||||
|
||||
class SeqAllToAllDim(torch.autograd.Function):
|
||||
"""
|
||||
all_to_all operation for unified sequence parallelism. uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange
|
||||
for more info.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, group, input, scatter_id=2, gather_id=1):
|
||||
ctx.group = group
|
||||
ctx.scatter_id = scatter_id
|
||||
ctx.gather_id = gather_id
|
||||
return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_outputs):
|
||||
grad_input = SeqAllToAllDim.apply(
|
||||
ctx.group,
|
||||
grad_outputs,
|
||||
ctx.gather_id, # reversed
|
||||
ctx.scatter_id, # reversed
|
||||
)
|
||||
return (None, grad_input, None, None)
|
||||
|
||||
|
||||
class TemplatedRingAttention(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
@@ -1237,7 +1334,10 @@ class TemplatedRingAttention(torch.autograd.Function):
|
||||
out = out.to(torch.float32)
|
||||
lse = lse.to(torch.float32)
|
||||
|
||||
lse = lse.unsqueeze(-1)
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if is_torch_version("<", "2.9.0"):
|
||||
lse = lse.unsqueeze(-1)
|
||||
if prev_out is not None:
|
||||
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
|
||||
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
|
||||
@@ -1298,7 +1398,7 @@ class TemplatedRingAttention(torch.autograd.Function):
|
||||
|
||||
grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
|
||||
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class TemplatedUlyssesAttention(torch.autograd.Function):
|
||||
@@ -1393,7 +1493,69 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
|
||||
x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
|
||||
)
|
||||
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def _templated_unified_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor],
|
||||
dropout_p: float,
|
||||
is_causal: bool,
|
||||
scale: Optional[float],
|
||||
enable_gqa: bool,
|
||||
return_lse: bool,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
scatter_idx: int = 2,
|
||||
gather_idx: int = 1,
|
||||
):
|
||||
"""
|
||||
Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
|
||||
"""
|
||||
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
|
||||
ulysses_group = ulysses_mesh.get_group()
|
||||
|
||||
query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
|
||||
key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
|
||||
value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
|
||||
out = TemplatedRingAttention.apply(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
context_layer, lse, *_ = out
|
||||
else:
|
||||
context_layer = out
|
||||
# context_layer is of shape (B, S, H_LOCAL, D)
|
||||
output = SeqAllToAllDim.apply(
|
||||
ulysses_group,
|
||||
context_layer,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
)
|
||||
if return_lse:
|
||||
# lse is of shape (B, S, H_LOCAL, 1)
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if is_torch_version("<", "2.9.0"):
|
||||
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
|
||||
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
|
||||
lse = lse.squeeze(-1)
|
||||
return (output, lse)
|
||||
return output
|
||||
|
||||
|
||||
def _templated_context_parallel_attention(
|
||||
@@ -1419,7 +1581,25 @@ def _templated_context_parallel_attention(
|
||||
raise ValueError("GQA is not yet supported for templated attention.")
|
||||
|
||||
# TODO: add support for unified attention with ring/ulysses degree both being > 1
|
||||
if _parallel_config.context_parallel_config.ring_degree > 1:
|
||||
if (
|
||||
_parallel_config.context_parallel_config.ring_degree > 1
|
||||
and _parallel_config.context_parallel_config.ulysses_degree > 1
|
||||
):
|
||||
return _templated_unified_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
elif _parallel_config.context_parallel_config.ring_degree > 1:
|
||||
return TemplatedRingAttention.apply(
|
||||
query,
|
||||
key,
|
||||
|
||||
@@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
text_encoder_target_modules = ["q", "k", "v", "o"]
|
||||
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 8, 8, 3)
|
||||
@@ -114,23 +116,3 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@unittest.skip("Not supported in AuraFlow.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
text_encoder_target_modules = ["q", "k", "v", "o"]
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 9, 16, 16, 3)
|
||||
@@ -147,26 +149,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
@@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"text_encoder",
|
||||
)
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 32, 32, 3)
|
||||
@@ -162,23 +164,3 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@unittest.skip("Not supported in CogView4.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
|
||||
denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 8, 8, 3)
|
||||
@@ -146,23 +148,3 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@unittest.skip("Not supported in Flux2.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"text_encoder_2",
|
||||
)
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 9, 32, 32, 3)
|
||||
@@ -172,26 +174,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
|
||||
@@ -150,6 +150,8 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
denoiser_target_modules = ["to_q", "to_k", "to_out.0"]
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 5, 32, 32, 3)
|
||||
@@ -267,27 +269,3 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@unittest.skip("Not supported in LTX2.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
|
||||
def test_simple_inference_save_pretrained_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@@ -76,6 +76,8 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
text_encoder_target_modules = ["q", "k", "v", "o"]
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 9, 32, 32, 3)
|
||||
@@ -125,23 +127,3 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@unittest.skip("Not supported in LTXVideo.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -74,6 +74,8 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma"
|
||||
text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers"
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 4, 4, 3)
|
||||
@@ -113,26 +115,6 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@skip_mps
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
|
||||
@@ -67,6 +67,8 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
text_encoder_target_modules = ["q", "k", "v", "o"]
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 7, 16, 16, 3)
|
||||
@@ -117,26 +119,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
@@ -69,6 +69,8 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
)
|
||||
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 8, 8, 3)
|
||||
@@ -107,23 +109,3 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@unittest.skip("Not supported in Qwen Image.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -75,6 +75,8 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
|
||||
text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 32, 32, 3)
|
||||
@@ -117,26 +119,6 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
def test_layerwise_casting_inference_denoiser(self):
|
||||
return super().test_layerwise_casting_inference_denoiser()
|
||||
|
||||
@@ -73,6 +73,8 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
text_encoder_target_modules = ["q", "k", "v", "o"]
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 9, 32, 32, 3)
|
||||
@@ -121,23 +123,3 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@unittest.skip("Not supported in Wan.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -85,6 +85,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
text_encoder_target_modules = ["q", "k", "v", "o"]
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 9, 16, 16, 3)
|
||||
@@ -139,26 +141,6 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
def test_layerwise_casting_inference_denoiser(self):
|
||||
super().test_layerwise_casting_inference_denoiser()
|
||||
|
||||
|
||||
@@ -75,6 +75,8 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline
|
||||
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 32, 32, 3)
|
||||
@@ -263,23 +265,3 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@unittest.skip("Not supported in ZImage.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -117,6 +117,7 @@ class PeftLoraLoaderMixinTests:
|
||||
tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
|
||||
tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
|
||||
tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
|
||||
supports_text_encoder_loras = True
|
||||
|
||||
unet_kwargs = None
|
||||
transformer_cls = None
|
||||
@@ -333,6 +334,9 @@ class PeftLoraLoaderMixinTests:
|
||||
Tests a simple inference with lora attached on the text encoder
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
if not self.supports_text_encoder_loras:
|
||||
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
|
||||
|
||||
components, text_lora_config, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
@@ -457,6 +461,9 @@ class PeftLoraLoaderMixinTests:
|
||||
Tests a simple inference with lora attached on the text encoder + scale argument
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
if not self.supports_text_encoder_loras:
|
||||
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
|
||||
|
||||
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
|
||||
components, text_lora_config, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
@@ -494,6 +501,9 @@ class PeftLoraLoaderMixinTests:
|
||||
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
if not self.supports_text_encoder_loras:
|
||||
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
|
||||
|
||||
components, text_lora_config, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
@@ -555,6 +565,9 @@ class PeftLoraLoaderMixinTests:
|
||||
"""
|
||||
Tests a simple usecase where users could use saving utilities for LoRA.
|
||||
"""
|
||||
if not self.supports_text_encoder_loras:
|
||||
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
|
||||
|
||||
components, text_lora_config, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
@@ -593,6 +606,9 @@ class PeftLoraLoaderMixinTests:
|
||||
with different ranks and some adapters removed
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
if not self.supports_text_encoder_loras:
|
||||
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
|
||||
|
||||
components, _, _ = self.get_dummy_components()
|
||||
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
|
||||
text_lora_config = LoraConfig(
|
||||
@@ -651,6 +667,9 @@ class PeftLoraLoaderMixinTests:
|
||||
"""
|
||||
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
|
||||
"""
|
||||
if not self.supports_text_encoder_loras:
|
||||
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
|
||||
|
||||
components, text_lora_config, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user