Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 04:54:47 +08:00)

Compare commits: `diffusers-...` → `fix/lora-d...` (17 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 62ece6ab5a | |
| | 7722f5b67c | |
| | 0281e85827 | |
| | 9106e66382 | |
| | 2fb3d141c2 | |
| | 6e1b06c01c | |
| | 50f2544697 | |
| | fc5fc8c8d2 | |
| | 030bb528ba | |
| | f803d3d1f5 | |
| | 3758d7a8b0 | |
| | 3a794b54c9 | |
| | fe623f3bea | |
| | bc65f829b7 | |
| | c22be1a557 | |
| | 05f716d4ac | |
| | 25b0d5b8c4 | |
```diff
@@ -755,17 +755,26 @@ def main(args):
         # Set the `lora_layer` attribute of the attention-related matrices.
         attn_module.to_q.set_lora_layer(
             LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+                in_features=attn_module.to_q.in_features,
+                out_features=attn_module.to_q.out_features,
+                rank=args.rank,
+                dtype=torch.float32,
             )
         )
         attn_module.to_k.set_lora_layer(
             LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+                in_features=attn_module.to_k.in_features,
+                out_features=attn_module.to_k.out_features,
+                rank=args.rank,
+                dtype=torch.float32,
             )
         )
         attn_module.to_v.set_lora_layer(
             LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+                in_features=attn_module.to_v.in_features,
+                out_features=attn_module.to_v.out_features,
+                rank=args.rank,
+                dtype=torch.float32,
             )
         )
         attn_module.to_out[0].set_lora_layer(
@@ -773,6 +782,7 @@ def main(args):
                 in_features=attn_module.to_out[0].in_features,
                 out_features=attn_module.to_out[0].out_features,
                 rank=args.rank,
+                dtype=torch.float32,
             )
         )
```
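The two hunks above appear to come from the LoRA training example's `main(args)`: each one-line `LoRALinearLayer(...)` call for `to_q`, `to_k`, `to_v`, and `to_out[0]` is reformatted and given an explicit `dtype=torch.float32`, so the adapters are created in full precision even when the base model runs in fp16. The sketch below is a minimal plain-`torch` illustration of that intent (the shapes, rank, and variable names are made up for the example; it is not the diffusers API):

```python
import torch
import torch.nn as nn

# Frozen base projection, as it would sit inside an fp16 UNet (weights only shown here).
base = nn.Linear(320, 320, bias=False).half().requires_grad_(False)

# Trainable LoRA factors kept in fp32 -- what `dtype=torch.float32` asks for.
lora_down = nn.Linear(320, 4, bias=False)
lora_up = nn.Linear(4, 320, bias=False)
nn.init.zeros_(lora_up.weight)  # usual LoRA init: the adapter starts as a no-op

x = torch.randn(2, 77, 320, dtype=torch.float16)
adapter = lora_up(lora_down(x.float()))  # adapter math runs in fp32

print(base.weight.dtype)          # torch.float16 -- frozen base stays in half precision
print(lora_down.weight.dtype)     # torch.float32 -- optimizer updates won't underflow
print(adapter.to(x.dtype).dtype)  # torch.float16 -- cast back before joining the fp16 graph
```

Keeping only the small LoRA matrices in fp32 costs little memory but avoids the degenerate case where fp16 adapter weights receive updates too small to be representable.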
```diff
@@ -207,6 +207,10 @@ class BasicTransformerBlock(nn.Module):
             )
         else:
             norm_hidden_states = self.norm1(hidden_states)
+        # print("After first norm")
+        # print(f"hidden_states: {hidden_states.dtype}")
+        # print(f"norm_hidden_states: {norm_hidden_states.dtype}")
+        # print(f"encoder_hidden_states: {norm_hidden_states.dtype}")
 
         # 1. Retrieve lora scale.
         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
@@ -223,7 +227,9 @@ class BasicTransformerBlock(nn.Module):
         )
         if self.use_ada_layer_norm_zero:
             attn_output = gate_msa.unsqueeze(1) * attn_output
+        # print(f"attn_output: {attn_output.dtype}")
         hidden_states = attn_output + hidden_states
+        # print(f"attn_output: {attn_output.dtype}")
 
         # 2.5 GLIGEN Control
         if gligen_kwargs is not None:
```
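These two hunks only add commented-out `print` calls for tracing dtypes through `BasicTransformerBlock`. An alternative that needs no edits to library code is a forward hook; the sketch below is illustrative only (the stand-in `block` and the hooked submodule names are assumptions, not part of the diff):

```python
import torch
import torch.nn as nn
from collections import OrderedDict

def log_dtypes(name):
    """Forward hook that prints the input/output dtypes of a submodule."""
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        print(f"{name}: in={inputs[0].dtype} out={out.dtype}")
    return hook

# Stand-in for a transformer block; in practice you would attach the hooks to the
# real `norm1` / `attn1` submodules of a BasicTransformerBlock inside the UNet.
block = nn.Sequential(OrderedDict(norm1=nn.LayerNorm(320), attn1=nn.Linear(320, 320)))
handles = [m.register_forward_hook(log_dtypes(n)) for n, m in block.named_children()]

block(torch.randn(2, 77, 320))  # prints the dtype seen at each hooked step
for h in handles:
    h.remove()                  # always detach hooks when done
```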
```diff
@@ -84,6 +84,7 @@ class LoRALinearLayer(nn.Module):
         nn.init.zeros_(self.up.weight)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # print(f"From {self.__class__.__name__}: hidden_states: {hidden_states.dtype}")
         orig_dtype = hidden_states.dtype
         dtype = self.down.weight.dtype
 
@@ -93,7 +94,9 @@ class LoRALinearLayer(nn.Module):
         if self.network_alpha is not None:
             up_hidden_states *= self.network_alpha / self.rank
 
-        return up_hidden_states.to(orig_dtype)
+        out = up_hidden_states.to(orig_dtype)
+        # print(f"From {self.__class__.__name__}: out: {out.dtype}")
+        return out
 
 
 class LoRAConv2dLayer(nn.Module):
```
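Taken together, the `LoRALinearLayer` hunks show the dtype round-trip this branch relies on: the input is cast to the adapter weights' dtype (fp32 here), pushed through `down` and `up`, optionally scaled by `network_alpha / rank`, and returned in the caller's original dtype. A self-contained sketch of that pattern (a simplified stand-in, not the actual diffusers class):

```python
import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    """Simplified stand-in for the forward() dtype round-trip shown in the diff."""

    def __init__(self, in_features, out_features, rank=4, network_alpha=None, dtype=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, dtype=dtype)
        self.rank = rank
        self.network_alpha = network_alpha
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype           # fp32 when constructed with dtype=torch.float32
        up_hidden_states = self.up(self.down(hidden_states.to(dtype)))
        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank
        return up_hidden_states.to(orig_dtype)   # hand the caller's dtype back (e.g. fp16)

layer = TinyLoRALinear(320, 320, rank=4, dtype=torch.float32)
x = torch.randn(2, 320, dtype=torch.float16)
print(layer(x).dtype)  # torch.float16, even though the adapter weights are fp32
```

The extra `out` variable introduced by the diff exists only so the (commented-out) debug print can inspect the final dtype; functionally the return value is unchanged.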