Compare commits

17 Commits

Author        SHA1        Message                                      Date
sayakpaul     62ece6ab5a  debug                                        2023-10-16 11:45:49 +05:30
sayakpaul     7722f5b67c  debug                                        2023-10-16 11:44:45 +05:30
sayakpaul     0281e85827  debug                                        2023-10-16 11:36:48 +05:30
sayakpaul     9106e66382  debug                                        2023-10-16 11:32:50 +05:30
sayakpaul     2fb3d141c2  debug                                        2023-10-16 11:30:39 +05:30
sayakpaul     6e1b06c01c  debug                                        2023-10-16 11:25:46 +05:30
sayakpaul     50f2544697  debug                                        2023-10-16 11:21:06 +05:30
sayakpaul     fc5fc8c8d2  debug                                        2023-10-16 10:44:03 +05:30
sayakpaul     030bb528ba  debug                                        2023-10-16 10:39:51 +05:30
sayakpaul     f803d3d1f5  Merge branch 'main' into fix/lora-dtype      2023-10-16 09:46:56 +05:30
Sayak Paul    3758d7a8b0  debug                                        2023-10-13 15:51:32 +05:30
Sayak Paul    3a794b54c9  derbug                                       2023-10-13 15:44:39 +05:30
Sayak Paul    fe623f3bea  derbug                                       2023-10-13 15:34:48 +05:30
Sayak Paul    bc65f829b7  debug                                        2023-10-13 14:41:43 +05:30
Sayak Paul    c22be1a557  debug                                        2023-10-13 14:36:48 +05:30
Sayak Paul    05f716d4ac  debug                                        2023-10-13 14:12:28 +05:30
Sayak Paul    25b0d5b8c4  debug                                        2023-10-13 14:09:12 +05:30
3 changed files with 23 additions and 4 deletions

View File

@@ -755,17 +755,26 @@ def main(args):
         # Set the `lora_layer` attribute of the attention-related matrices.
         attn_module.to_q.set_lora_layer(
             LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+                in_features=attn_module.to_q.in_features,
+                out_features=attn_module.to_q.out_features,
+                rank=args.rank,
+                dtype=torch.float32,
             )
         )
         attn_module.to_k.set_lora_layer(
             LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+                in_features=attn_module.to_k.in_features,
+                out_features=attn_module.to_k.out_features,
+                rank=args.rank,
+                dtype=torch.float32,
             )
         )
         attn_module.to_v.set_lora_layer(
             LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+                in_features=attn_module.to_v.in_features,
+                out_features=attn_module.to_v.out_features,
+                rank=args.rank,
+                dtype=torch.float32,
             )
         )
         attn_module.to_out[0].set_lora_layer(
@@ -773,6 +782,7 @@ def main(args):
                 in_features=attn_module.to_out[0].in_features,
                 out_features=attn_module.to_out[0].out_features,
                 rank=args.rank,
+                dtype=torch.float32,
             )
         )
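
For reference, a minimal usage sketch of the constructor change above. The import path and the feature sizes are assumptions for illustration only; the point is that passing dtype=torch.float32 creates the LoRA matrices in fp32 regardless of the precision the surrounding UNet modules run in.

import torch
from diffusers.models.lora import LoRALinearLayer  # import path assumed for this branch

lora = LoRALinearLayer(in_features=320, out_features=320, rank=4, dtype=torch.float32)
print(lora.down.weight.dtype)  # torch.float32
print(lora.up.weight.dtype)    # torch.float32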

View File

@@ -207,6 +207,10 @@ class BasicTransformerBlock(nn.Module):
             )
         else:
             norm_hidden_states = self.norm1(hidden_states)
+        # print("After first norm")
+        # print(f"hidden_states: {hidden_states.dtype}")
+        # print(f"norm_hidden_states: {norm_hidden_states.dtype}")
+        # print(f"encoder_hidden_states: {encoder_hidden_states.dtype}")

         # 1. Retrieve lora scale.
         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
@@ -223,7 +227,9 @@ class BasicTransformerBlock(nn.Module):
         )
         if self.use_ada_layer_norm_zero:
             attn_output = gate_msa.unsqueeze(1) * attn_output
+        # print(f"attn_output: {attn_output.dtype}")
         hidden_states = attn_output + hidden_states
+        # print(f"attn_output: {attn_output.dtype}")

         # 2.5 GLIGEN Control
         if gligen_kwargs is not None:

View File

@@ -84,6 +84,7 @@ class LoRALinearLayer(nn.Module):
         nn.init.zeros_(self.up.weight)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # print(f"From {self.__class__.__name__}: hidden_states: {hidden_states.dtype}")
         orig_dtype = hidden_states.dtype
         dtype = self.down.weight.dtype
@@ -93,7 +94,9 @@ class LoRALinearLayer(nn.Module):
         if self.network_alpha is not None:
             up_hidden_states *= self.network_alpha / self.rank

-        return up_hidden_states.to(orig_dtype)
+        out = up_hidden_states.to(orig_dtype)
+        # print(f"From {self.__class__.__name__}: out: {out.dtype}")
+        return out


 class LoRAConv2dLayer(nn.Module):
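
To make the dtype round-trip in the forward hunk above easier to follow, here is a self-contained toy sketch in plain PyTorch. The class and variable names are illustrative, not the diffusers class itself: the LoRA weights live in fp32, the input is upcast for the two projections, and the result is cast back to the input's original dtype.

import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    # Toy stand-in mirroring the dtype handling in LoRALinearLayer.forward above.
    def __init__(self, in_features, out_features, rank=4, dtype=torch.float32):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, dtype=dtype)
        nn.init.zeros_(self.up.weight)  # LoRA starts as a no-op

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype            # fp32
        out = self.up(self.down(hidden_states.to(dtype)))
        return out.to(orig_dtype)                 # cast back to the caller's dtype

x = torch.randn(2, 77, 320, dtype=torch.float16)  # e.g. hidden states from an fp16 UNet
layer = TinyLoRALinear(320, 320, rank=4)
print(layer.down.weight.dtype, layer(x).dtype)    # torch.float32 torch.float16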