Add: Support for multiple hidden layers in Eagle3 (#26164)
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
@@ -22,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
             "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
             id="qwen3-eagle3-speculator-w4a16-verifier",
         ),
+        pytest.param(
+            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
+            id="llama3-eagl3-multiple-layers",
+        ),
     ],
 )
 def test_eagle3_speculators_model(
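The new pytest.param above registers a randomly initialized two-layer Eagle3 checkpoint so the test matrix covers drafts with more than one decoder layer. As a rough usage sketch (not part of the commit), loading such a speculator through vLLM's offline API might look like the following; the verifier model name and the exact speculative_config keys are assumptions that can vary between vLLM releases:

# Hedged sketch: running a multi-layer Eagle3 speculator (names are illustrative).
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # assumed verifier model
    speculative_config={
        "model": "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
        "method": "eagle3",
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)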
@@ -34,15 +34,20 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         vllm_config: VllmConfig,
         prefix: str = "",
         config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
     ) -> None:
         super().__init__(vllm_config, prefix=prefix, config=config)
 
         config = config or vllm_config.model_config.hf_config
         quant_config = self.get_quant_config(vllm_config)
 
+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
+
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
             self.self_attn.head_dim,
             self.self_attn.total_num_heads,
             self.self_attn.total_num_kv_heads,
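The qkv_input_size switch captures the core of the change: only the first draft layer consumes the concatenation of the token embeddings with the hidden states handed over by the verifier, so only its QKV projection needs an input width of 2*hidden_size. A minimal shape check in plain PyTorch (a stand-in for QKVParallelLinear, which is not used here) illustrates the sizing:

import torch
import torch.nn as nn

hidden_size, seq_len = 64, 4
embeds = torch.randn(seq_len, hidden_size)         # token embeddings
hidden_states = torch.randn(seq_len, hidden_size)  # hidden states fed to the draft

# Layer 0: embeds and hidden_states are concatenated, so the QKV input doubles.
layer0_qkv = nn.Linear(2 * hidden_size, 3 * hidden_size)
qkv0 = layer0_qkv(torch.cat([embeds, hidden_states], dim=-1))

# Subsequent layers only see hidden_states, so the QKV input stays at hidden_size.
layer1_qkv = nn.Linear(hidden_size, 3 * hidden_size)
qkv1 = layer1_qkv(hidden_states)

print(qkv0.shape, qkv1.shape)  # both (seq_len, 3 * hidden_size)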
@@ -52,6 +57,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         )
 
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx
 
         if getattr(config, "norm_before_residual", False):
             self._residual_norm = self._norm_before_residual
@@ -90,11 +96,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
-
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
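In the forward pass, layer 0 normalizes the embeddings and the incoming hidden states separately and concatenates them, while every later layer runs an ordinary pre-norm update on the hidden_states/residual pair and ignores the embeddings. The toy module below sketches that dataflow with plain LayerNorm and a single Linear standing in for attention plus MLP; it is a simplified illustration under those assumptions, not vLLM's implementation:

import torch
import torch.nn as nn

class ToyEagle3Layer(nn.Module):
    def __init__(self, hidden_size: int, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        in_size = 2 * hidden_size if layer_idx == 0 else hidden_size
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.hidden_norm = nn.LayerNorm(hidden_size)
        self.mixer = nn.Linear(in_size, hidden_size)  # stand-in for attention + MLP

    def forward(self, embeds, hidden_states, residual):
        if self.layer_idx == 0:
            # First layer: normalize both streams, then concatenate along features.
            embeds = self.input_layernorm(embeds)
            residual = hidden_states
            hidden_states = self.hidden_norm(hidden_states)
            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
        else:
            # Later layers: pre-norm over the running residual stream, no embeds.
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.mixer(hidden_states)
        return hidden_states, residual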
@@ -133,9 +143,11 @@ class LlamaModel(nn.Module):
             [
                 LlamaDecoderLayer(
                     current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                     config=self.config,
+                    layer_idx=layer_idx,
                 )
+                for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
         if hasattr(self.config, "target_hidden_size"):
@@ -166,13 +178,13 @@ class LlamaModel(nn.Module):
         assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
 
         hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
         return hidden_states, hidden_prenorm
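At the model level, the single call to self.layers[0] becomes a loop, so the same residual threading works for any num_hidden_layers. A short driver over the ToyEagle3Layer sketch above (illustrative names, not vLLM's API) shows the loop keeping the stream at hidden_size after every layer:

import torch

hidden_size, seq_len, num_layers = 64, 4, 2
layers = [ToyEagle3Layer(hidden_size, layer_idx) for layer_idx in range(num_layers)]

input_embeds = torch.randn(seq_len, hidden_size)
hidden_states = torch.randn(seq_len, hidden_size)  # stand-in for verifier features

residual = None  # only consumed by layers after the first
for layer in layers:
    hidden_states, residual = layer(input_embeds, hidden_states, residual)

print(hidden_states.shape)  # torch.Size([4, 64]) regardless of num_layers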