Compare commits

...

2 Commits

Author SHA1 Message Date
Patrick von Platen
23d616de52 up 2023-03-16 16:12:14 +00:00
Patrick von Platen
cfc129d669 up 2023-03-16 16:11:33 +00:00
2 changed files with 10 additions and 3 deletions

View File

@@ -808,7 +808,7 @@ class CrossAttnDownBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, down_block_res=None, cross_attention_kwargs=None
):
# TODO(Patrick, William) - attention mask is not used
output_states = ()
@@ -843,6 +843,8 @@ class CrossAttnDownBlock2D(nn.Module):
output_states += (hidden_states,)
if self.downsamplers is not None:
if down_block_res is not None:
hidden_states += down_block_res
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)

View File

@@ -576,23 +576,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# 2. pre-process
sample = self.conv_in(sample)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_t2i = mid_block_additional_residual is None and down_block_additional_residuals is not None
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
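# NOTE: `is_t2i` is currently inferred above from which residual arguments were passed,
# not from the shape of the residual connections; when set, the popped residual is forwarded as `down_block_res`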
kwargs = {} if not is_t2i else {"down_block_res": down_block_additional_residuals.pop()}
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
**kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
if down_block_additional_residuals is not None:
if is_controlnet:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
@@ -613,7 +618,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
cross_attention_kwargs=cross_attention_kwargs,
)
if mid_block_additional_residual is not None:
if is_controlnet:
sample = sample + mid_block_additional_residual
# 5. up