Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-25 05:44:52 +08:00
Compare commits: 1 commit (`remove-unn`...`hidream-ex`)

Commit `a330fe01d7`
```diff
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
```
```diff
@@ -6,9 +7,8 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
```
```diff
@@ -17,6 +17,29 @@ from ..embeddings import TimestepEmbedding, Timesteps
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+@dataclass
+class HiDreamImageModelOutput(BaseOutput):
+    sample: torch.Tensor
+    double_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None
+    single_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None
+
+
+class AddAuxiliaryLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, loss):
+        assert loss.numel() == 1
+        ctx.dtype = loss.dtype
+        ctx.required_aux_loss = loss.requires_grad
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_loss = None
+        if ctx.required_aux_loss:
+            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
+        return grad_output, grad_loss
+
+
 class HiDreamImageFeedForwardSwiGLU(nn.Module):
     def __init__(
         self,
```
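`AddAuxiliaryLoss` is the usual straight-through trick for MoE auxiliary objectives (the same pattern appears in DeepSeek-style MoE code): the forward pass returns `x` untouched, and the backward pass injects a constant gradient of 1.0 into the auxiliary loss, so the router's balance term reaches the optimizer without ever altering the block's output values. A minimal sketch of the behavior; the class body is copied from the hunk above, while the toy tensors around it are illustrative only:

```python
import torch


class AddAuxiliaryLoss(torch.autograd.Function):
    # Copied from the hunk above: identity in forward, unit gradient for `loss` in backward.
    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss


w = torch.randn(4, requires_grad=True)   # stands in for router parameters
x = torch.randn(4)                       # stands in for the expert mixture output
aux_loss = (w.softmax(-1) ** 2).sum()    # toy auxiliary loss that depends on w
y = AddAuxiliaryLoss.apply(x, aux_loss)  # y is numerically identical to x
y.sum().backward()
print(w.grad)                            # non-zero: d(aux_loss)/dw flowed in with weight 1.0
```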
```diff
@@ -332,7 +355,6 @@ class MoEGate(nn.Module):
         else:
             mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
             ce = mask_ce.float().mean(0)
-
             Pi = scores_for_aux.mean(0)
             fi = ce * self.n_routed_experts
             aux_loss = (Pi * fi).sum() * self.alpha
```
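This hunk only tightens whitespace, but the surrounding code is the standard expert-load-balancing loss: `Pi` is the mean router probability per expert, `ce` the empirical fraction of top-k assignments each expert received, `fi = ce * n_routed_experts` rescales that fraction so perfectly uniform routing gives `fi == 1`, and `aux_loss = alpha * (Pi * fi).sum()` grows when probability mass and token load pile onto a few experts. A standalone sketch of the same computation under assumed toy shapes (8 tokens, 4 experts, top-2 routing); only the tensor names are taken from the hunk:

```python
import torch
import torch.nn.functional as F

alpha, n_routed_experts, num_experts_per_tok = 0.01, 4, 2

# Router probabilities per token (rows sum to 1) and the chosen top-k experts.
scores_for_aux = torch.randn(8, n_routed_experts).softmax(dim=-1)
topk_idx_for_aux_loss = scores_for_aux.topk(num_experts_per_tok, dim=-1).indices

# ce: fraction of all top-k slots that landed on each expert (sums to 1).
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)          # mean gate probability per expert
fi = ce * n_routed_experts           # == 1.0 for every expert under uniform routing
aux_loss = (Pi * fi).sum() * alpha   # larger when routing collapses onto few experts
print(aux_loss)
```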
```diff
@@ -379,11 +401,11 @@ class MOEFeedForwardSwiGLU(nn.Module):
                 y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
             y = y.view(*orig_shape).to(dtype=wtype)
-            # y = AddAuxiliaryLoss.apply(y, aux_loss)
+            y = AddAuxiliaryLoss.apply(y, aux_loss)
         else:
             y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
         y = y + self.shared_experts(identity)
-        return y
+        return y, aux_loss
 
     @torch.no_grad()
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
```
```diff
@@ -481,9 +503,10 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
         # 2. Feed-forward
         norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
-        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+        ff_output_i, aux_loss = self.ff_i(norm_hidden_states.to(dtype=wtype))
+        ff_output_i = gate_mlp_i * ff_output_i
         hidden_states = ff_output_i + hidden_states
-        return hidden_states
+        return hidden_states, aux_loss
 
 
 @maybe_allow_in_graph
```
```diff
@@ -573,11 +596,12 @@ class HiDreamImageTransformerBlock(nn.Module):
         norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
 
-        ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+        ff_output_i, aux_loss = self.ff_i(norm_hidden_states)
+        ff_output_i = gate_mlp_i * ff_output_i
         ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
         hidden_states = ff_output_i + hidden_states
         encoder_hidden_states = ff_output_t + encoder_hidden_states
-        return hidden_states, encoder_hidden_states
+        return hidden_states, encoder_hidden_states, aux_loss
 
 
 class HiDreamBlock(nn.Module):
```
```diff
@@ -785,6 +809,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         hidden_states_masks: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        return_auxiliary_loss: bool = False,
         **kwargs,
     ):
         encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
```
```diff
@@ -866,15 +891,19 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
 
         # 2. Blocks
         block_id = 0
+        double_blocks_aux_losses = []
+        single_blocks_aux_losses = []
+
         initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
         initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
+
         for bid, block in enumerate(self.double_stream_blocks):
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
             cur_encoder_hidden_states = torch.cat(
                 [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
             )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
+                hidden_states, initial_encoder_hidden_states, aux_loss = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     hidden_states_masks,
```
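The checkpointed call now unpacks three values instead of two. In diffusers, `_gradient_checkpointing_func` is a thin wrapper around `torch.utils.checkpoint.checkpoint`, which passes tuple outputs through, so the extra `aux_loss` tensor survives the recompute-on-backward path. A minimal sketch with a toy function standing in for the transformer block (all names here are illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint


def toy_block(hidden_states):
    # Stands in for a block that returns both activations and an auxiliary loss.
    aux_loss = hidden_states.float().pow(2).mean()
    return hidden_states * 2.0, aux_loss


x = torch.randn(2, 8, requires_grad=True)
# Activations are freed after forward and recomputed during backward;
# the tuple structure, including aux_loss, is preserved.
hidden_states, aux_loss = checkpoint(toy_block, x, use_reentrant=False)
(hidden_states.sum() + aux_loss).backward()
print(x.grad.shape)
```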
```diff
@@ -883,7 +912,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                 )
             else:
-                hidden_states, initial_encoder_hidden_states = block(
+                hidden_states, initial_encoder_hidden_states, aux_loss = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=cur_encoder_hidden_states,
```
```diff
@@ -891,6 +920,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                 image_rotary_emb=image_rotary_emb,
             )
             initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
+            double_blocks_aux_losses.append(aux_loss)
             block_id += 1
 
         image_tokens_seq_len = hidden_states.shape[1]
```
```diff
@@ -908,7 +938,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
             cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
             hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                hidden_states, aux_loss = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     hidden_states_masks,
```
```diff
@@ -917,7 +947,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                 )
             else:
-                hidden_states = block(
+                hidden_states, aux_loss = block(
                     hidden_states=hidden_states,
                     hidden_states_masks=hidden_states_masks,
                     encoder_hidden_states=None,
```
```diff
@@ -925,6 +955,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                 image_rotary_emb=image_rotary_emb,
             )
             hidden_states = hidden_states[:, :hidden_states_seq_len]
+            single_blocks_aux_losses.append(aux_loss)
             block_id += 1
 
         hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
```
```diff
@@ -938,5 +969,13 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
             unscale_lora_layers(self, lora_scale)
 
         if not return_dict:
-            return (output,)
-        return Transformer2DModelOutput(sample=output)
+            return_values = (output,)
+            if return_auxiliary_loss:
+                return_values += (double_blocks_aux_losses, single_blocks_aux_losses)
+            return return_values
+
+        return HiDreamImageModelOutput(
+            sample=output,
+            double_blocks_auxiliary_loss=double_blocks_aux_losses,
+            single_blocks_auxiliary_loss=single_blocks_aux_losses,
+        )
```
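Taken together, a training loop can now surface the per-block balance terms and fold them into its objective. A hedged usage sketch: `transformer`, `model_inputs`, and `target` are assumed to be set up elsewhere (the usual HiDream forward kwargs and a denoising target), and `aux_weight` is an illustrative hyperparameter; none of these names are defined in this commit:

```python
import torch
import torch.nn.functional as F

aux_weight = 0.01  # illustrative weighting for the MoE balance terms

# Dataclass path (return_dict=True is the default): aux losses ride on the output object.
out = transformer(**model_inputs)  # -> HiDreamImageModelOutput
diffusion_loss = F.mse_loss(out.sample.float(), target.float())
aux_terms = list(out.double_blocks_auxiliary_loss) + list(out.single_blocks_auxiliary_loss)
# Assumes training mode, so each collected term is a scalar tensor.
loss = diffusion_loss + aux_weight * sum(aux_terms)
loss.backward()

# Tuple path: return_dict=False plus the new flag yields a 3-tuple.
sample, double_aux, single_aux = transformer(
    **model_inputs, return_dict=False, return_auxiliary_loss=True
)
```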