Add: Support for multiple hidden layers in Eagle3 (#26164)

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Author: Rahul Tuli
Date: 2025-10-09 13:00:50 +05:30
Committed by: GitHub
Parent: b960441812
Commit: cf4cd6c24f

2 changed files with 29 additions and 13 deletions


@@ -22,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
id="qwen3-eagle3-speculator-w4a16-verifier",
),
pytest.param(
"nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
id="llama3-eagl3-multiple-layers",
),
],
)
def test_eagle3_speculators_model(
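The new parametrization exercises a two-layer Eagle3 drafter. As a rough illustration of how such a speculator might be run outside the test suite, here is a minimal offline-inference sketch; the verifier model name and the speculative_config values are assumptions chosen for illustration, not taken from this test.

# Sketch only (not part of this commit): running an Eagle3 drafter through
# vLLM's offline API. The verifier model and config values are assumed examples.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # assumed verifier
    speculative_config={
        "model": "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
        "method": "eagle3",
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)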


@@ -34,15 +34,20 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         vllm_config: VllmConfig,
         prefix: str = "",
         config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
     ) -> None:
         super().__init__(vllm_config, prefix=prefix, config=config)
         config = config or vllm_config.model_config.hf_config
         quant_config = self.get_quant_config(vllm_config)
+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
             self.self_attn.head_dim,
             self.self_attn.total_num_heads,
             self.self_attn.total_num_kv_heads,
@@ -52,6 +57,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         )
 
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx
 
         if getattr(config, "norm_before_residual", False):
             self._residual_norm = self._norm_before_residual
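The sizing rule introduced above can be stated on its own: only layer 0 consumes the concatenation of embeddings and hidden states, so only its QKV projection needs the doubled input width. A standalone sketch with an assumed hidden_size:

# Standalone sketch of the qkv input-size rule; hidden_size is an assumed value.
hidden_size = 4096

for layer_idx in range(2):
    qkv_input_size = 2 * hidden_size if layer_idx == 0 else hidden_size
    # layer 0 -> 8192 (embeds + hidden_states concatenated), layer 1 -> 4096
    print(layer_idx, qkv_input_size)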
@@ -90,11 +96,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
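A shape-only check of the branching above, using plain tensors with an assumed token count and hidden size, shows why the first layer's attention input is twice as wide as that of the later layers:

# Shape sketch only; sizes are assumptions and the norms are omitted.
import torch

num_tokens, hidden_size = 8, 4096
embeds = torch.randn(num_tokens, hidden_size)
hidden_states = torch.randn(num_tokens, hidden_size)

# layer_idx == 0: embeds and hidden_states are concatenated on the last dim
first_input = torch.cat([embeds, hidden_states], dim=-1)
assert first_input.shape[-1] == 2 * hidden_size  # matches qkv_input_size for layer 0

# layer_idx > 0: only hidden_states flows into attention
later_input = hidden_states
assert later_input.shape[-1] == hidden_size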
@@ -133,9 +143,11 @@ class LlamaModel(nn.Module):
             [
                 LlamaDecoderLayer(
                     current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                     config=self.config,
+                    layer_idx=layer_idx,
                 )
+                for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
         if hasattr(self.config, "target_hidden_size"):
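The comprehension above now builds config.num_hidden_layers drafter layers and offsets each prefix by start_layer_id, presumably so drafter layer names continue after the verifier's. A small naming sketch, where start_layer_id = 32 is an assumed example (e.g. a 32-layer verifier):

# Naming sketch only; start_layer_id and the layer count are assumed examples.
start_layer_id = 32
num_hidden_layers = 2

prefixes = [f"layers.{layer_idx + start_layer_id}" for layer_idx in range(num_hidden_layers)]
print(prefixes)  # ['layers.32', 'layers.33']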
@@ -166,13 +178,13 @@ class LlamaModel(nn.Module):
         assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
 
         hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
         return hidden_states, hidden_prenorm