Compare commits

...

6 Commits

Author      SHA1        Date                        Message
Sayak Paul  ac1b61dbe8  2024-03-08 09:16:38 +05:30  Merge branch 'main' into fix/transformers2d-config
sayakpaul   b7b58686dc  2024-03-05 18:23:24 +05:30  fix: norm_type handling
sayakpaul   5ddcb5b8fb  2024-03-05 13:10:19 +05:30  more check
sayakpaul   16627e9ac9  2024-03-05 12:09:00 +05:30  add comment on supported norm_types in transformers2d
sayakpaul   091a7d55c8  2024-03-05 11:55:48 +05:30  Merge branch 'main' into fix/transformers2d-config
sayakpaul   db1f7021eb  2024-03-04 08:55:12 +05:30  throw error when patch inputs and layernorm are provided for transformers2d.
2 changed files with 12 additions and 2 deletions

@@ -143,7 +143,7 @@ class BasicTransformerBlock(nn.Module):
         double_self_attention: bool = False,
         upcast_attention: bool = False,
         norm_elementwise_affine: bool = True,
-        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'layer_norm_i2vgen'
+        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
         norm_eps: float = 1e-5,
         final_dropout: bool = False,
         attention_type: str = "default",
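
This hunk only updates the comment documenting which norm_type strings BasicTransformerBlock accepts; behavior is unchanged. The short sketch below is not part of the PR (the dimensions and import path are my assumptions) and just shows a block built with the default layer_norm setting:

```python
# Minimal sketch, assuming a diffusers install that contains this change.
# Dimensions are illustrative: dim must equal num_attention_heads * attention_head_dim.
import torch
from diffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    norm_type="layer_norm",  # other listed values need extra conditioning kwargs
)

hidden_states = torch.randn(1, 16, 320)  # (batch, sequence_length, dim)
out = block(hidden_states)
print(out.shape)  # torch.Size([1, 16, 320])
```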

@@ -92,7 +92,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         only_cross_attention: bool = False,
         double_self_attention: bool = False,
         upcast_attention: bool = False,
-        norm_type: str = "layer_norm",
+        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
         norm_elementwise_affine: bool = True,
         norm_eps: float = 1e-5,
         attention_type: str = "default",
@@ -100,6 +100,16 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         interpolation_scale: float = None,
     ):
         super().__init__()
+        if patch_size is not None:
+            if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
+                raise NotImplementedError(
+                    f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
+                )
+            elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None:
+                raise ValueError(
+                    f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
+                )
+
         self.use_linear_projection = use_linear_projection
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
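
As a quick sanity check of the new guard, the hedged sketch below (not part of the diff; the argument values are illustrative) constructs Transformer2DModel with patch inputs plus either an unsupported norm_type or a missing num_embeds_ada_norm, and prints the errors now raised at init time:

```python
# Hedged sketch, assuming a diffusers build that includes this change;
# the argument values are illustrative, not taken from the PR.
from diffusers import Transformer2DModel

# Patch inputs with plain layer norm: the forward pass is unsupported,
# so construction now fails fast with NotImplementedError.
try:
    Transformer2DModel(patch_size=2, norm_type="layer_norm")
except NotImplementedError as err:
    print(err)

# Patch inputs with ada_norm_zero but no num_embeds_ada_norm: rejected with ValueError.
try:
    Transformer2DModel(patch_size=2, norm_type="ada_norm_zero")
except ValueError as err:
    print(err)
```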