Compare commits

...

1 Commits

Author SHA1 Message Date
Aryan
a330fe01d7 update 2025-06-06 11:16:39 +02:00

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
@@ -6,9 +7,8 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention
from ..embeddings import TimestepEmbedding, Timesteps
@@ -17,6 +17,29 @@ from ..embeddings import TimestepEmbedding, Timesteps
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class HiDreamImageModelOutput(BaseOutput):
sample: torch.Tensor
double_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None
single_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None
class AddAuxiliaryLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, x, loss):
assert loss.numel() == 1
ctx.dtype = loss.dtype
ctx.required_aux_loss = loss.requires_grad
return x
@staticmethod
def backward(ctx, grad_output):
grad_loss = None
if ctx.required_aux_loss:
grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
return grad_output, grad_loss
class HiDreamImageFeedForwardSwiGLU(nn.Module):
def __init__(
self,
@@ -332,7 +355,6 @@ class MoEGate(nn.Module):
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
@@ -379,11 +401,11 @@ class MOEFeedForwardSwiGLU(nn.Module):
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.view(*orig_shape).to(dtype=wtype)
# y = AddAuxiliaryLoss.apply(y, aux_loss)
y = AddAuxiliaryLoss.apply(y, aux_loss)
else:
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
return y, aux_loss
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
@@ -481,9 +503,10 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
# 2. Feed-forward
norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
ff_output_i, aux_loss = self.ff_i(norm_hidden_states.to(dtype=wtype))
ff_output_i = gate_mlp_i * ff_output_i
hidden_states = ff_output_i + hidden_states
return hidden_states
return hidden_states, aux_loss
@maybe_allow_in_graph
@@ -573,11 +596,12 @@ class HiDreamImageTransformerBlock(nn.Module):
norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
ff_output_i, aux_loss = self.ff_i(norm_hidden_states)
ff_output_i = gate_mlp_i * ff_output_i
ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
hidden_states = ff_output_i + hidden_states
encoder_hidden_states = ff_output_t + encoder_hidden_states
return hidden_states, encoder_hidden_states
return hidden_states, encoder_hidden_states, aux_loss
class HiDreamBlock(nn.Module):
@@ -785,6 +809,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
hidden_states_masks: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
return_auxiliary_loss: bool = False,
**kwargs,
):
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
@@ -866,15 +891,19 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
# 2. Blocks
block_id = 0
double_blocks_aux_losses = []
single_blocks_aux_losses = []
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
for bid, block in enumerate(self.double_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
cur_encoder_hidden_states = torch.cat(
[initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
hidden_states, initial_encoder_hidden_states, aux_loss = self._gradient_checkpointing_func(
block,
hidden_states,
hidden_states_masks,
@@ -883,7 +912,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
image_rotary_emb,
)
else:
hidden_states, initial_encoder_hidden_states = block(
hidden_states, initial_encoder_hidden_states, aux_loss = block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
encoder_hidden_states=cur_encoder_hidden_states,
@@ -891,6 +920,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
image_rotary_emb=image_rotary_emb,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
double_blocks_aux_losses.append(aux_loss)
block_id += 1
image_tokens_seq_len = hidden_states.shape[1]
@@ -908,7 +938,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
hidden_states, aux_loss = self._gradient_checkpointing_func(
block,
hidden_states,
hidden_states_masks,
@@ -917,7 +947,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
image_rotary_emb,
)
else:
hidden_states = block(
hidden_states, aux_loss = block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
encoder_hidden_states=None,
@@ -925,6 +955,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
single_blocks_aux_losses.append(aux_loss)
block_id += 1
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
@@ -938,5 +969,13 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
return_values = (output,)
if return_auxiliary_loss:
return_values += (double_blocks_aux_losses, single_blocks_aux_losses)
return return_values
return HiDreamImageModelOutput(
sample=output,
double_blocks_auxiliary_loss=double_blocks_aux_losses,
single_blocks_auxiliary_loss=single_blocks_aux_losses,
)