mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-19 15:08:09 +08:00
Compare commits
2 Commits
misc-docs-
...
t2i_adapte
| Author | SHA1 | Date | |
|---|---|---|---|
| | 23d616de52 | | |
| | cfc129d669 | | |
@@ -808,7 +808,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
|||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, down_block_res=None, cross_attention_kwargs=None
|
||||||
):
|
):
|
||||||
# TODO(Patrick, William) - attention mask is not used
|
# TODO(Patrick, William) - attention mask is not used
|
||||||
output_states = ()
|
output_states = ()
|
||||||
@@ -843,6 +843,8 @@ class CrossAttnDownBlock2D(nn.Module):
|
|||||||
output_states += (hidden_states,)
|
output_states += (hidden_states,)
|
||||||
|
|
||||||
if self.downsamplers is not None:
|
if self.downsamplers is not None:
|
||||||
|
if down_block_res is not None:
|
||||||
|
hidden_states += down_block_res
|
||||||
for downsampler in self.downsamplers:
|
for downsampler in self.downsamplers:
|
||||||
hidden_states = downsampler(hidden_states)
|
hidden_states = downsampler(hidden_states)
|
||||||
|
|
||||||
|
|||||||
@@ -576,23 +576,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
|||||||
# 2. pre-process
|
# 2. pre-process
|
||||||
sample = self.conv_in(sample)
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
|
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
||||||
|
is_t2i = mid_block_additional_residual is None and down_block_additional_residuals is not None
|
||||||
# 3. down
|
# 3. down
|
||||||
down_block_res_samples = (sample,)
|
down_block_res_samples = (sample,)
|
||||||
for downsample_block in self.down_blocks:
|
for downsample_block in self.down_blocks:
|
||||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||||
|
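# in T2I-Adapter mode (`is_t2i`: down-block residuals given but no mid-block residual), pass this block one residual
# NOTE(review): `.pop()` removes the LAST list element while the blocks are visited first-to-last - confirm the expected ordering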
|
||||||
|
kwargs = {} if not is_t2i else {"down_block_res": down_block_additional_residuals.pop()}
|
||||||
sample, res_samples = downsample_block(
|
sample, res_samples = downsample_block(
|
||||||
hidden_states=sample,
|
hidden_states=sample,
|
||||||
temb=emb,
|
temb=emb,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
cross_attention_kwargs=cross_attention_kwargs,
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||||
|
|
||||||
down_block_res_samples += res_samples
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
if down_block_additional_residuals is not None:
|
if is_controlnet:
|
||||||
new_down_block_res_samples = ()
|
new_down_block_res_samples = ()
|
||||||
|
|
||||||
for down_block_res_sample, down_block_additional_residual in zip(
|
for down_block_res_sample, down_block_additional_residual in zip(
|
||||||
@@ -613,7 +618,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
|||||||
cross_attention_kwargs=cross_attention_kwargs,
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if mid_block_additional_residual is not None:
|
if is_controlnet:
|
||||||
sample = sample + mid_block_additional_residual
|
sample = sample + mid_block_additional_residual
|
||||||
|
|
||||||
# 5. up
|
# 5. up
|
||||||
|
|||||||
Reference in New Issue
Block a user