Compare commits

...

2 Commits

Author SHA1 Message Date
Patrick von Platen
23d616de52 up 2023-03-16 16:12:14 +00:00
Patrick von Platen
cfc129d669 up 2023-03-16 16:11:33 +00:00
2 changed files with 10 additions and 3 deletions

View File

@@ -808,7 +808,7 @@ class CrossAttnDownBlock2D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward( def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, down_block_res=None, cross_attention_kwargs=None
): ):
# TODO(Patrick, William) - attention mask is not used # TODO(Patrick, William) - attention mask is not used
output_states = () output_states = ()
@@ -843,6 +843,8 @@ class CrossAttnDownBlock2D(nn.Module):
output_states += (hidden_states,) output_states += (hidden_states,)
if self.downsamplers is not None: if self.downsamplers is not None:
if down_block_res is not None:
hidden_states += down_block_res
for downsampler in self.downsamplers: for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states) hidden_states = downsampler(hidden_states)

View File

@@ -576,23 +576,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# 2. pre-process # 2. pre-process
sample = self.conv_in(sample) sample = self.conv_in(sample)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_t2i = mid_block_additional_residual is None and down_block_additional_residuals is not None
# 3. down # 3. down
down_block_res_samples = (sample,) down_block_res_samples = (sample,)
for downsample_block in self.down_blocks: for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# find out whether `is_t2i` depending on the shape of the residual connections
kwargs = {} if not is_t2i else {"down_block_res": down_block_additional_residuals.pop()}
sample, res_samples = downsample_block( sample, res_samples = downsample_block(
hidden_states=sample, hidden_states=sample,
temb=emb, temb=emb,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
**kwargs,
) )
else: else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb) sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples down_block_res_samples += res_samples
if down_block_additional_residuals is not None: if is_controlnet:
new_down_block_res_samples = () new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip( for down_block_res_sample, down_block_additional_residual in zip(
@@ -613,7 +618,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
) )
if mid_block_additional_residual is not None: if is_controlnet:
sample = sample + mid_block_additional_residual sample = sample + mid_block_additional_residual
# 5. up # 5. up