Refactor Transformers backend to use mixins (#26906)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
2
.github/CODEOWNERS
vendored
2
.github/CODEOWNERS
vendored
@@ -57,7 +57,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
|||||||
/tests/v1/offloading @ApostaC
|
/tests/v1/offloading @ApostaC
|
||||||
|
|
||||||
# Transformers backend
|
# Transformers backend
|
||||||
/vllm/model_executor/models/transformers.py @hmellor
|
/vllm/model_executor/models/transformers @hmellor
|
||||||
/tests/models/test_transformers.py @hmellor
|
/tests/models/test_transformers.py @hmellor
|
||||||
|
|
||||||
# Docs
|
# Docs
|
||||||
|
|||||||
@@ -912,11 +912,11 @@ _TRANSFORMERS_BACKEND_MODELS = {
|
|||||||
"TransformersForCausalLM": _HfExamplesInfo(
|
"TransformersForCausalLM": _HfExamplesInfo(
|
||||||
"hmellor/Ilama-3.2-1B", trust_remote_code=True
|
"hmellor/Ilama-3.2-1B", trust_remote_code=True
|
||||||
),
|
),
|
||||||
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
|
"TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
|
||||||
"TransformersMoEForCausalLM": _HfExamplesInfo(
|
"TransformersMoEForCausalLM": _HfExamplesInfo(
|
||||||
"allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"
|
"allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"
|
||||||
),
|
),
|
||||||
"TransformersMoEForMultimodalLM": _HfExamplesInfo(
|
"TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
|
||||||
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"
|
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"
|
||||||
),
|
),
|
||||||
"TransformersMoEEmbeddingModel": _HfExamplesInfo(
|
"TransformersMoEEmbeddingModel": _HfExamplesInfo(
|
||||||
@@ -925,6 +925,10 @@ _TRANSFORMERS_BACKEND_MODELS = {
|
|||||||
"TransformersMoEForSequenceClassification": _HfExamplesInfo(
|
"TransformersMoEForSequenceClassification": _HfExamplesInfo(
|
||||||
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
|
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
|
||||||
),
|
),
|
||||||
|
"TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"),
|
||||||
|
"TransformersMultiModalForSequenceClassification": _HfExamplesInfo(
|
||||||
|
"google/gemma-3-4b-it"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
_EXAMPLE_MODELS = {
|
_EXAMPLE_MODELS = {
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ MINIMAL_MODEL_ARCH_LIST = [
|
|||||||
"JinaVLForRanking",
|
"JinaVLForRanking",
|
||||||
"InternVLChatModel",
|
"InternVLChatModel",
|
||||||
"InternLM2ForRewardModel",
|
"InternLM2ForRewardModel",
|
||||||
"TransformersForMultimodalLM",
|
"TransformersMultiModalForCausalLM",
|
||||||
"PrithviGeoSpatialMAE",
|
"PrithviGeoSpatialMAE",
|
||||||
"UltravoxModel",
|
"UltravoxModel",
|
||||||
"DeepSeekMTPModel",
|
"DeepSeekMTPModel",
|
||||||
|
|||||||
@@ -211,11 +211,7 @@ def test_embed_loading(vllm_runner, model):
|
|||||||
def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
|
def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
|
||||||
model = get_model(arch)
|
model = get_model(arch)
|
||||||
|
|
||||||
vllm_kwargs = dict(
|
vllm_kwargs = dict(max_model_len=None, model_impl="transformers")
|
||||||
max_model_len=None,
|
|
||||||
model_impl="transformers",
|
|
||||||
compilation_config=dict(cudagraph_capture_sizes=[8]),
|
|
||||||
)
|
|
||||||
|
|
||||||
hf_kwargs = dict()
|
hf_kwargs = dict()
|
||||||
if arch == "TransformersEmbeddingModel":
|
if arch == "TransformersEmbeddingModel":
|
||||||
|
|||||||
@@ -147,6 +147,10 @@ class ModelConfig:
|
|||||||
seed: int | None = None
|
seed: int | None = None
|
||||||
"""Random seed for reproducibility. Initialized to None in V0, but
|
"""Random seed for reproducibility. Initialized to None in V0, but
|
||||||
initialized to 0 in V1."""
|
initialized to 0 in V1."""
|
||||||
|
hf_config: PretrainedConfig = field(init=False)
|
||||||
|
"""The Hugging Face config of the model."""
|
||||||
|
hf_text_config: PretrainedConfig = field(init=False)
|
||||||
|
"""The Hugging Face config of the text model (same as hf_config for text models)."""
|
||||||
hf_config_path: str | None = None
|
hf_config_path: str | None = None
|
||||||
"""Name or path of the Hugging Face config to use. If unspecified, model
|
"""Name or path of the Hugging Face config to use. If unspecified, model
|
||||||
name or path will be used."""
|
name or path will be used."""
|
||||||
@@ -771,8 +775,10 @@ class ModelConfig:
|
|||||||
def _get_transformers_backend_cls(self) -> str:
|
def _get_transformers_backend_cls(self) -> str:
|
||||||
"""Determine which Transformers backend class will be used if
|
"""Determine which Transformers backend class will be used if
|
||||||
`model_impl` is set to `transformers` or `auto`."""
|
`model_impl` is set to `transformers` or `auto`."""
|
||||||
prefix = "Transformers"
|
cls = "Transformers"
|
||||||
prefix += "MoE" if self.get_num_experts() > 1 else ""
|
# If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal
|
||||||
|
cls += "MultiModal" if self.hf_config != self.hf_text_config else ""
|
||||||
|
cls += "MoE" if self.get_num_experts() > 1 else ""
|
||||||
# Check if the architecture we're wrapping has defaults
|
# Check if the architecture we're wrapping has defaults
|
||||||
runner = None
|
runner = None
|
||||||
convert = None
|
convert = None
|
||||||
@@ -788,18 +794,15 @@ class ModelConfig:
|
|||||||
runner = "generate"
|
runner = "generate"
|
||||||
if convert in {None, "none"}:
|
if convert in {None, "none"}:
|
||||||
convert = "embed"
|
convert = "embed"
|
||||||
# Resolve Transformers backend pooling classes
|
# Resolve Transformers backend task
|
||||||
if runner == "pooling":
|
if runner == "pooling":
|
||||||
if convert == "embed":
|
if convert == "embed":
|
||||||
return prefix + "EmbeddingModel"
|
return cls + "EmbeddingModel"
|
||||||
if convert == "classify":
|
if convert == "classify":
|
||||||
return prefix + "ForSequenceClassification"
|
return cls + "ForSequenceClassification"
|
||||||
# Resolve Transformers backend generate classes
|
else:
|
||||||
if self.hf_config != self.hf_text_config:
|
cls += "ForCausalLM"
|
||||||
# If 'hf_text_config' is the same as 'hf_config'. If not, it is
|
return cls
|
||||||
# probably a composite config, i.e. multimodal
|
|
||||||
return prefix + "ForMultimodalLM"
|
|
||||||
return prefix + "ForCausalLM"
|
|
||||||
|
|
||||||
def using_transformers_backend(self) -> bool:
|
def using_transformers_backend(self) -> bool:
|
||||||
"""Check if the model is using the Transformers backend class."""
|
"""Check if the model is using the Transformers backend class."""
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from vllm.config.multimodal import BaseDummyOptions
|
|||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
from vllm.model_executor.models.transformers import replace_linear_class
|
from vllm.model_executor.models.transformers.utils import replace_linear_class
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (
|
from vllm.multimodal.inputs import (
|
||||||
MultiModalDataDict,
|
MultiModalDataDict,
|
||||||
|
|||||||
@@ -401,32 +401,44 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
|
|||||||
# Text generation models
|
# Text generation models
|
||||||
"SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
|
"SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
|
||||||
# Multimodal models
|
# Multimodal models
|
||||||
"Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
|
"Emu3ForConditionalGeneration": (
|
||||||
|
"transformers",
|
||||||
|
"TransformersMultiModalForCausalLM",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
_TRANSFORMERS_BACKEND_MODELS = {
|
_TRANSFORMERS_BACKEND_MODELS = {
|
||||||
|
# Text generation models
|
||||||
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
|
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
|
||||||
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
|
"TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
|
||||||
"TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501
|
# Multimodal models
|
||||||
"TransformersMoEForMultimodalLM": (
|
"TransformersMultiModalForCausalLM": (
|
||||||
"transformers_moe",
|
"transformers",
|
||||||
"TransformersMoEForMultimodalLM",
|
"TransformersMultiModalForCausalLM",
|
||||||
),
|
),
|
||||||
"TransformersEmbeddingModel": (
|
"TransformersMultiModalMoEForCausalLM": (
|
||||||
"transformers_pooling",
|
"transformers",
|
||||||
"TransformersEmbeddingModel",
|
"TransformersMultiModalMoEForCausalLM",
|
||||||
),
|
),
|
||||||
|
# Embedding models
|
||||||
|
"TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
|
||||||
|
"TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
|
||||||
|
"TransformersMultiModalEmbeddingModel": (
|
||||||
|
"transformers",
|
||||||
|
"TransformersMultiModalEmbeddingModel",
|
||||||
|
),
|
||||||
|
# Sequence classification models
|
||||||
"TransformersForSequenceClassification": (
|
"TransformersForSequenceClassification": (
|
||||||
"transformers_pooling",
|
"transformers",
|
||||||
"TransformersForSequenceClassification",
|
"TransformersForSequenceClassification",
|
||||||
),
|
),
|
||||||
"TransformersMoEForSequenceClassification": (
|
"TransformersMoEForSequenceClassification": (
|
||||||
"transformers_pooling",
|
"transformers",
|
||||||
"TransformersMoEForSequenceClassification",
|
"TransformersMoEForSequenceClassification",
|
||||||
),
|
),
|
||||||
"TransformersMoEEmbeddingModel": (
|
"TransformersMultiModalForSequenceClassification": (
|
||||||
"transformers_pooling",
|
"transformers",
|
||||||
"TransformersMoEEmbeddingModel",
|
"TransformersMultiModalForSequenceClassification",
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,961 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
# Copyright 2024 The vLLM team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Wrapper around `transformers` models"""
|
|
||||||
|
|
||||||
from collections.abc import Iterable, Mapping
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import regex as re
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from packaging.version import Version
|
|
||||||
from torch import nn
|
|
||||||
from transformers import AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel
|
|
||||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionType
|
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
|
||||||
from vllm.config import (
|
|
||||||
CacheConfig,
|
|
||||||
DeviceConfig,
|
|
||||||
ModelConfig,
|
|
||||||
ParallelConfig,
|
|
||||||
VllmConfig,
|
|
||||||
)
|
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
|
||||||
from vllm.config.utils import getattr_iter
|
|
||||||
from vllm.distributed import get_pp_group, get_tp_group
|
|
||||||
from vllm.distributed.utils import get_pp_indices
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
|
||||||
from vllm.model_executor.layers.linear import (
|
|
||||||
ColumnParallelLinear,
|
|
||||||
ReplicatedLinear,
|
|
||||||
RowParallelLinear,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
||||||
ParallelLMHead,
|
|
||||||
VocabParallelEmbedding,
|
|
||||||
)
|
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
|
|
||||||
from vllm.multimodal.inputs import (
|
|
||||||
MultiModalDataDict,
|
|
||||||
MultiModalFieldConfig,
|
|
||||||
MultiModalInputs,
|
|
||||||
MultiModalUUIDDict,
|
|
||||||
PlaceholderRange,
|
|
||||||
)
|
|
||||||
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
|
|
||||||
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
|
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
|
||||||
from vllm.sequence import IntermediateTensors
|
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant
|
|
||||||
from .utils import (
|
|
||||||
AutoWeightsLoader,
|
|
||||||
PPMissingLayer,
|
|
||||||
WeightsMapper,
|
|
||||||
make_empty_intermediate_tensors_factory,
|
|
||||||
maybe_prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def get_feature_request_tip(
|
|
||||||
model: str,
|
|
||||||
trust_remote_code: bool,
|
|
||||||
) -> str:
|
|
||||||
hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new"
|
|
||||||
gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
|
|
||||||
url = hf_url if trust_remote_code else gh_url
|
|
||||||
prefix = f"Please open {url} to request support for this feature. "
|
|
||||||
if Path(model).exists():
|
|
||||||
prefix = ""
|
|
||||||
doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
|
|
||||||
tip = f"See {doc_url} for instructions on how to add support yourself."
|
|
||||||
return f"{prefix}{tip}"
|
|
||||||
|
|
||||||
|
|
||||||
def vllm_flash_attention_forward(
|
|
||||||
# Transformers args
|
|
||||||
module: torch.nn.Module,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
attention_mask: torch.Tensor,
|
|
||||||
# Transformers kwargs
|
|
||||||
scaling: float | None = None,
|
|
||||||
# vLLM kwargs
|
|
||||||
attention_instances: dict[Attention] | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self_attn = attention_instances[module.layer_idx]
|
|
||||||
if scaling is not None:
|
|
||||||
self_attn.impl.scale = float(scaling)
|
|
||||||
hidden = query.shape[-2]
|
|
||||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
|
||||||
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
|
|
||||||
return self_attn.forward(query, key, value), None
|
|
||||||
|
|
||||||
|
|
||||||
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
|
|
||||||
|
|
||||||
|
|
||||||
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
|
|
||||||
logger.debug("%s: %s -> %s", name, old_module, new_module)
|
|
||||||
|
|
||||||
|
|
||||||
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
|
|
||||||
"""
|
|
||||||
Callable to be passed to `@support_torch_compile`'s `enable_if` argument.
|
|
||||||
|
|
||||||
Defaults to `True` but is disabled in the following situations:
|
|
||||||
|
|
||||||
- The model uses dynamic rope scaling.
|
|
||||||
"""
|
|
||||||
enable = True
|
|
||||||
text_config = vllm_config.model_config.hf_config.get_text_config()
|
|
||||||
# Dynamic rope scaling is not compatible with torch.compile
|
|
||||||
rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
|
|
||||||
if rope_scaling.get("rope_type") == "dynamic":
|
|
||||||
enable = False
|
|
||||||
return enable
|
|
||||||
|
|
||||||
|
|
||||||
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
|
|
||||||
|
|
||||||
|
|
||||||
def replace_linear_class(
|
|
||||||
linear: nn.Linear,
|
|
||||||
style: Style = "replicate",
|
|
||||||
quant_config: QuantizationConfig | None = None,
|
|
||||||
*,
|
|
||||||
prefix: str = "",
|
|
||||||
) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear:
|
|
||||||
"""
|
|
||||||
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
linear: `nn.Linear` to be replaced.
|
|
||||||
style: Tensor parallel style of the new linear, e.g. "colwise".
|
|
||||||
quant_config: Quantization config for the new linear.
|
|
||||||
Returns:
|
|
||||||
The new linear.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not isinstance(style, str):
|
|
||||||
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
|
|
||||||
|
|
||||||
vllm_linear_cls, vllm_linear_kwargs = {
|
|
||||||
"colwise": (ColumnParallelLinear, {}),
|
|
||||||
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
|
|
||||||
"rowwise": (RowParallelLinear, {}),
|
|
||||||
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
|
|
||||||
"replicate": (ReplicatedLinear, {}),
|
|
||||||
}.get(style, (ReplicatedLinear, {}))
|
|
||||||
|
|
||||||
return vllm_linear_cls(
|
|
||||||
input_size=linear.in_features,
|
|
||||||
output_size=linear.out_features,
|
|
||||||
bias=linear.bias is not None,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=prefix,
|
|
||||||
return_bias=False,
|
|
||||||
**vllm_linear_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
|
|
||||||
"""Replace a Transformers RMSNorm with vLLM's RMSNorm.
|
|
||||||
|
|
||||||
This method assumes:
|
|
||||||
- Weight is stored as `weight`.
|
|
||||||
- Epsilon is stored as `eps` or `variance_epsilon`.
|
|
||||||
- `with_scale` indicates whether the layer has a weight (Gemma3n only).
|
|
||||||
- `var_hidden_size` is only ever used for Intern vision encoder in vLLM
|
|
||||||
and Transformers doesn't appear to have the same concept.
|
|
||||||
"""
|
|
||||||
eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
|
|
||||||
kwargs = {"hidden_size": hidden_size, "eps": eps}
|
|
||||||
# Update hidden size if weight is available
|
|
||||||
weight_meta = getattr(rms_norm, "weight", None)
|
|
||||||
if weight_meta is not None:
|
|
||||||
kwargs["hidden_size"] = weight_meta.size(0)
|
|
||||||
# Check if weight is all zeros, which indicates GemmaRMSNorm
|
|
||||||
# We must create a new instance because rms_norm is on meta
|
|
||||||
try:
|
|
||||||
with torch.device("cpu"):
|
|
||||||
weight_test = getattr(rms_norm.__class__(1), "weight", None)
|
|
||||||
except Exception:
|
|
||||||
logger.warning(
|
|
||||||
"Failed to determine if RMSNorm weight is centered on zero or one. "
|
|
||||||
"Defaulting to one."
|
|
||||||
)
|
|
||||||
weight_test = None
|
|
||||||
if weight_test is not None and torch.all(weight_test == 0):
|
|
||||||
return GemmaRMSNorm(**kwargs)
|
|
||||||
# Otherwise assume it's a regular RMSNorm
|
|
||||||
kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
|
|
||||||
if weight_meta is not None:
|
|
||||||
kwargs["dtype"] = weight_meta.dtype
|
|
||||||
else:
|
|
||||||
# No weight, fall back to weightless RMSNorm
|
|
||||||
kwargs["has_weight"] = False
|
|
||||||
return RMSNorm(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from `accelerate`
|
|
||||||
@contextmanager
|
|
||||||
def init_on_device_without_buffers(device: torch.device):
|
|
||||||
"""
|
|
||||||
A context manager under which models are initialized with all
|
|
||||||
parameters on the specified device. However buffers are not
|
|
||||||
initialized on specified device.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
device (`torch.device`):
|
|
||||||
Device to initialize all parameters on.
|
|
||||||
"""
|
|
||||||
|
|
||||||
old_register_parameter = nn.Module.register_parameter
|
|
||||||
|
|
||||||
def register_empty_parameter(module, name, param):
|
|
||||||
old_register_parameter(module, name, param)
|
|
||||||
if param is not None:
|
|
||||||
param_cls = type(module._parameters[name])
|
|
||||||
kwargs = module._parameters[name].__dict__
|
|
||||||
kwargs["requires_grad"] = param.requires_grad
|
|
||||||
module._parameters[name] = param_cls(
|
|
||||||
module._parameters[name].to(device), **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
tensor_constructors_to_patch = {}
|
|
||||||
|
|
||||||
def patch_tensor_constructor(fn):
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
kwargs["device"] = device
|
|
||||||
return fn(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
try:
|
|
||||||
nn.Module.register_parameter = register_empty_parameter
|
|
||||||
for torch_function_name in tensor_constructors_to_patch:
|
|
||||||
setattr(
|
|
||||||
torch,
|
|
||||||
torch_function_name,
|
|
||||||
patch_tensor_constructor(getattr(torch, torch_function_name)),
|
|
||||||
)
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
nn.Module.register_parameter = old_register_parameter
|
|
||||||
for (
|
|
||||||
torch_function_name,
|
|
||||||
old_torch_function,
|
|
||||||
) in tensor_constructors_to_patch.items():
|
|
||||||
setattr(torch, torch_function_name, old_torch_function)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiModalProcessingInfo(BaseProcessingInfo):
|
|
||||||
def get_supported_mm_limits(self):
|
|
||||||
return {"image": None}
|
|
||||||
|
|
||||||
def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
|
|
||||||
return {"image": self.get_max_image_tokens()}
|
|
||||||
|
|
||||||
def get_max_image_tokens(self) -> int:
|
|
||||||
width, height = self.get_max_image_size()
|
|
||||||
processor = self.get_hf_processor()
|
|
||||||
multimodal_config = self.ctx.model_config.multimodal_config
|
|
||||||
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
|
|
||||||
mm_tokens = processor._get_num_multimodal_tokens(
|
|
||||||
image_sizes=([height, width],), **mm_processor_kwargs
|
|
||||||
)
|
|
||||||
image_tokens = mm_tokens["num_image_tokens"][0]
|
|
||||||
return image_tokens
|
|
||||||
|
|
||||||
def get_max_image_size(self):
|
|
||||||
return 10_000, 10_000 # hardcode for arbitrary very large size
|
|
||||||
|
|
||||||
|
|
||||||
class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
|
|
||||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
|
||||||
num_images = mm_counts.get("image", 0)
|
|
||||||
|
|
||||||
processor = self.info.get_hf_processor()
|
|
||||||
if "gemma3" in processor.__class__.__name__.lower():
|
|
||||||
image_token = processor.boi_token
|
|
||||||
else:
|
|
||||||
image_token = getattr(processor, "image_token", "")
|
|
||||||
return image_token * num_images
|
|
||||||
|
|
||||||
def get_dummy_mm_data(
|
|
||||||
self,
|
|
||||||
seq_len: int,
|
|
||||||
mm_counts: Mapping[str, int],
|
|
||||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
|
||||||
) -> MultiModalDataDict:
|
|
||||||
num_images = mm_counts.get("image", 0)
|
|
||||||
|
|
||||||
target_width, target_height = self.info.get_max_image_size()
|
|
||||||
|
|
||||||
image_overrides = mm_options.get("image") if mm_options else None
|
|
||||||
|
|
||||||
return {
|
|
||||||
"image": self._get_dummy_images(
|
|
||||||
width=target_width,
|
|
||||||
height=target_height,
|
|
||||||
num_images=num_images,
|
|
||||||
overrides=image_overrides,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
|
||||||
def _get_prompt_updates(
|
|
||||||
self,
|
|
||||||
mm_items: MultiModalDataItems,
|
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
out_mm_kwargs: MultiModalKwargsItems,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Given the original multi-modal items for this modality
|
|
||||||
and HF-processed data, output the updates to perform.
|
|
||||||
|
|
||||||
The information returned by this method is used to update token inputs
|
|
||||||
which bypass the HF processor. It is also used to update the output of
|
|
||||||
HF processor if the HF process does not apply prompt updates to text
|
|
||||||
inputs.
|
|
||||||
|
|
||||||
Moreover, this information is critical to determine the token positions
|
|
||||||
in order to construct :class:`~vllm-multimodal.input.PlaceholderRange`
|
|
||||||
for each multi-modal item.
|
|
||||||
"""
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_mm_fields_config(
|
|
||||||
self,
|
|
||||||
hf_inputs: BatchFeature,
|
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
|
||||||
# HF Processors always return a mask but vLLM doesn't need it
|
|
||||||
hf_inputs.pop("attention_mask", None)
|
|
||||||
num_image_patches = hf_inputs.get("num_image_patches")
|
|
||||||
mm_fields = {
|
|
||||||
key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
|
|
||||||
for key in hf_inputs
|
|
||||||
}
|
|
||||||
mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
|
|
||||||
"image", num_image_patches
|
|
||||||
)
|
|
||||||
|
|
||||||
# Keep these as batched, as they always have batch size as first dim
|
|
||||||
mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
|
|
||||||
mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
|
|
||||||
mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
|
|
||||||
return mm_fields
|
|
||||||
|
|
||||||
def _get_hf_mm_data(
|
|
||||||
self,
|
|
||||||
mm_items: MultiModalDataItems,
|
|
||||||
) -> tuple[Mapping[str, object], Mapping[str, object]]:
|
|
||||||
"""
|
|
||||||
In contrast to the base class, this method always adds
|
|
||||||
`return_mm_token_type_ids` to the processor data
|
|
||||||
"""
|
|
||||||
processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
|
|
||||||
processor_data["return_mm_token_type_ids"] = True
|
|
||||||
return processor_data, passthrough_data
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
prompt: str | list[int],
|
|
||||||
mm_data: MultiModalDataDict,
|
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
|
||||||
mm_uuids: MultiModalUUIDDict | None = None,
|
|
||||||
) -> MultiModalInputs:
|
|
||||||
"""
|
|
||||||
Process multi-modal inputs to be used in vLLM.
|
|
||||||
|
|
||||||
Apply HF Processor on prompt text and multi-modal data together,
|
|
||||||
outputting token IDs and processed tensors.
|
|
||||||
"""
|
|
||||||
if tokenization_kwargs is None:
|
|
||||||
tokenization_kwargs = {}
|
|
||||||
|
|
||||||
mm_items = self._to_mm_items(mm_data)
|
|
||||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
|
||||||
if not isinstance(prompt, str):
|
|
||||||
# the prompt is the tokenized ids which is not supported
|
|
||||||
# by the hf_processor, which is why we would need to decode the ids
|
|
||||||
# into string
|
|
||||||
prompt = hf_processor.decode(prompt)
|
|
||||||
|
|
||||||
# Bypass cached processor and always apply to the full set of mm inputs
|
|
||||||
# NOTE: we can't just set caching=False because base class method
|
|
||||||
# transforms outputs to `MultiModalKwargs` which is not going to
|
|
||||||
# work for Transformers. We have a lot of logic tied to
|
|
||||||
# `mm_tokens_per_modality` below
|
|
||||||
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
|
|
||||||
prompt_text=prompt,
|
|
||||||
mm_items=mm_items,
|
|
||||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# For gemma3 we check `token_type_ids` as the key
|
|
||||||
token_type_key = (
|
|
||||||
"mm_token_type_ids"
|
|
||||||
if "mm_token_type_ids" in processed_data
|
|
||||||
else "token_type_ids"
|
|
||||||
)
|
|
||||||
mm_token_type_ids = processed_data.pop(token_type_key)
|
|
||||||
|
|
||||||
# We can infer vLLM style placeholder from token type ids, if we split
|
|
||||||
# it for each input `mm_data`.
|
|
||||||
mm_positions = torch.where(mm_token_type_ids == 1)[1]
|
|
||||||
images = mm_items.get_items("image", ImageProcessorItems)
|
|
||||||
multimodal_config = self.info.ctx.model_config.multimodal_config
|
|
||||||
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
|
|
||||||
image_sizes = []
|
|
||||||
for item_idx in range(len(images)):
|
|
||||||
image_size = images.get_image_size(item_idx)
|
|
||||||
image_sizes.append((image_size.height, image_size.width))
|
|
||||||
|
|
||||||
mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
|
|
||||||
image_sizes=image_sizes, **mm_processor_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
mm_placeholders = {}
|
|
||||||
split_sizes = mm_tokens_per_modality["num_image_tokens"]
|
|
||||||
if split_sizes:
|
|
||||||
chunked_mm_positions = torch.split(mm_positions, split_sizes)
|
|
||||||
mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
|
|
||||||
chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
|
|
||||||
ranges = [
|
|
||||||
PlaceholderRange(
|
|
||||||
offset=positions[0].item(),
|
|
||||||
length=positions.shape[0],
|
|
||||||
is_embed=(mm_tokens == hf_processor.image_token_id).bool(),
|
|
||||||
)
|
|
||||||
for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens)
|
|
||||||
]
|
|
||||||
mm_placeholders = {"image": ranges}
|
|
||||||
|
|
||||||
processed_data["num_image_patches"] = torch.tensor(
|
|
||||||
mm_tokens_per_modality["num_image_patches"]
|
|
||||||
)
|
|
||||||
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
|
|
||||||
processed_data,
|
|
||||||
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use overrides if provided; fallback to data-dependent hashing.
|
|
||||||
mm_hashes = self._hash_mm_items(
|
|
||||||
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
|
|
||||||
)
|
|
||||||
|
|
||||||
return MultiModalInputs(
|
|
||||||
type="multimodal",
|
|
||||||
prompt_token_ids=prompt_ids,
|
|
||||||
mm_kwargs=mm_kwargs,
|
|
||||||
mm_hashes=mm_hashes,
|
|
||||||
mm_placeholders=mm_placeholders,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|
||||||
embedding_padding_modules = ["lm_head"]
|
|
||||||
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
||||||
super().__init__()
|
|
||||||
logger.info("Using Transformers backend.")
|
|
||||||
|
|
||||||
self.config: PretrainedConfig = vllm_config.model_config.hf_config
|
|
||||||
self.text_config: PretrainedConfig = self.config.get_text_config()
|
|
||||||
self.cache_config: CacheConfig = vllm_config.cache_config
|
|
||||||
self.device_config: DeviceConfig = vllm_config.device_config
|
|
||||||
self.model_config: ModelConfig = vllm_config.model_config
|
|
||||||
self.parallel_config: ParallelConfig = vllm_config.parallel_config
|
|
||||||
self.quant_config: QuantizationConfig | None = vllm_config.quant_config
|
|
||||||
|
|
||||||
self.pp_group = get_pp_group()
|
|
||||||
self.tp_group = get_tp_group()
|
|
||||||
|
|
||||||
# Weights to skip in `self.load_weights`
|
|
||||||
self.skip_prefixes: list[str] = []
|
|
||||||
"""Skip loading weights whose qualname starts with these prefixes."""
|
|
||||||
self.skip_substrs: list[str] = []
|
|
||||||
"""Skip loading weights whose qualname contains these substrings."""
|
|
||||||
self.ignore_unexpected_prefixes: list[str] = []
|
|
||||||
"""Ignore unexpected weights whose qualname starts with these prefixes.
|
|
||||||
"""
|
|
||||||
self.ignore_unexpected_suffixes: list[str] = []
|
|
||||||
"""Ignore unexpected weights whose qualname ends with these suffixes."""
|
|
||||||
|
|
||||||
if self.quant_config:
|
|
||||||
quant_method_name = self.quant_config.get_name()
|
|
||||||
# Check for unsupported quantization methods.
|
|
||||||
if quant_method_name == "mxfp4":
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Transformers backend does not support MXFP4 quantization yet."
|
|
||||||
)
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if "gptq" in quant_method_name:
|
|
||||||
self.ignore_unexpected_suffixes.append(".bias")
|
|
||||||
|
|
||||||
# Set correct attn and init on "meta" to delay allocating GPU tensors
|
|
||||||
self.text_config._attn_implementation = "vllm"
|
|
||||||
with init_on_device_without_buffers("meta"):
|
|
||||||
self.model: PreTrainedModel = AutoModel.from_config(
|
|
||||||
self.config,
|
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
trust_remote_code=self.model_config.trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove layers not on this pipeline parallel rank
|
|
||||||
self.pipeline_parallel()
|
|
||||||
# Substitute remaining layers with vLLM's layers as needed
|
|
||||||
self.recursive_replace()
|
|
||||||
# Create attention instances for KV cache allocation
|
|
||||||
self.attention_instances = self.create_attention_instances()
|
|
||||||
|
|
||||||
# Input embeddings
|
|
||||||
input_embeddings = self.model.get_input_embeddings()
|
|
||||||
if not isinstance(input_embeddings, PPMissingLayer):
|
|
||||||
# Some models use embedding scales
|
|
||||||
self.embed_scale = getattr(input_embeddings, "embed_scale", None)
|
|
||||||
names = ("embedding_size", "hidden_size")
|
|
||||||
embedding_dim = getattr_iter(self.text_config, names, None)
|
|
||||||
assert embedding_dim is not None
|
|
||||||
self.model.set_input_embeddings(
|
|
||||||
VocabParallelEmbedding(
|
|
||||||
self.text_config.vocab_size,
|
|
||||||
embedding_dim=embedding_dim,
|
|
||||||
org_num_embeddings=self.text_config.vocab_size,
|
|
||||||
quant_config=self.quant_config,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize any parameters that have not had their modules replaced
|
|
||||||
self.init_parameters(self.model)
|
|
||||||
|
|
||||||
# Pipeline parallel intermediate tensors
|
|
||||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
|
||||||
["hidden_states"], self.text_config.hidden_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def pipeline_parallel(self):
|
|
||||||
"""
|
|
||||||
Apply the model's pipeline parallelization plan.
|
|
||||||
"""
|
|
||||||
if self.pp_group.world_size <= 1:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self.model.supports_pp_plan:
|
|
||||||
tip = get_feature_request_tip(
|
|
||||||
self.model_config.model, self.model_config.trust_remote_code
|
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
f"{type(self.model)} does not support pipeline parallel. {tip}"
|
|
||||||
)
|
|
||||||
|
|
||||||
module_lists = []
|
|
||||||
module_list_idx = None
|
|
||||||
pp_plan = list(self.model._pp_plan.keys())
|
|
||||||
for i, name in enumerate(pp_plan):
|
|
||||||
if isinstance(getattr(self.model, name), nn.ModuleList):
|
|
||||||
module_lists.append(name)
|
|
||||||
module_list_idx = i
|
|
||||||
|
|
||||||
if len(module_lists) > 1:
|
|
||||||
raise ValueError(
|
|
||||||
"Pipeline parallel of models with multiple `ModuleList`s "
|
|
||||||
"in the base model are not supported yet!"
|
|
||||||
)
|
|
||||||
if module_list_idx is None:
|
|
||||||
raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")
|
|
||||||
|
|
||||||
# Layers before module list
|
|
||||||
for name in pp_plan[:module_list_idx]:
|
|
||||||
if self.pp_group.is_first_rank or (
|
|
||||||
self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
setattr(self.model, name, PPMissingLayer())
|
|
||||||
|
|
||||||
# Module list
|
|
||||||
start_layer, end_layer = get_pp_indices(
|
|
||||||
self.text_config.num_hidden_layers,
|
|
||||||
self.pp_group.rank_in_group,
|
|
||||||
self.pp_group.world_size,
|
|
||||||
)
|
|
||||||
layers_name = pp_plan[module_list_idx]
|
|
||||||
layers = getattr(self.model, layers_name)
|
|
||||||
for i in range(len(layers)):
|
|
||||||
if start_layer <= i and i < end_layer:
|
|
||||||
continue
|
|
||||||
layers[i] = PPMissingLayer()
|
|
||||||
|
|
||||||
# Layers after module list
|
|
||||||
for name in pp_plan[module_list_idx + 1 :]:
|
|
||||||
# Modules that should be on last rank
|
|
||||||
if not self.pp_group.is_last_rank:
|
|
||||||
setattr(self.model, name, PPMissingLayer())
|
|
||||||
|
|
||||||
def recursive_replace(self):
|
|
||||||
"""Recursively replace modules in the model as needed.
|
|
||||||
|
|
||||||
Currently, this replaces:
|
|
||||||
|
|
||||||
- `nn.Linear` with vLLM's tensor parallel linear classes
|
|
||||||
- `*RMSNorm` with vLLM's `RMSNorm`
|
|
||||||
"""
|
|
||||||
tp_plan = self.model.tp_plan
|
|
||||||
|
|
||||||
if not tp_plan and self.tp_group.world_size > 1:
|
|
||||||
tip = get_feature_request_tip(
|
|
||||||
self.model_config.model, self.model_config.trust_remote_code
|
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
f"{type(self.model)} does not support tensor parallel. {tip}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prefix the patterns because we always start from `self.model`
|
|
||||||
tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}
|
|
||||||
|
|
||||||
def _recursive_replace(module: nn.Module, prefix: str):
|
|
||||||
for child_name, child_module in module.named_children():
|
|
||||||
new_module = child_module
|
|
||||||
qual_name = maybe_prefix(prefix, child_name)
|
|
||||||
if isinstance(child_module, nn.Linear):
|
|
||||||
generator = (p for p in tp_plan if re.match(p, qual_name))
|
|
||||||
pattern = next(generator, None)
|
|
||||||
# Some weight loaders expect all linear layers to inherit
|
|
||||||
# LinearBase, so we set a default style which causes any
|
|
||||||
# unspecified layers to be replaced with ReplicatedLinear
|
|
||||||
style = tp_plan.get(pattern, "replicate")
|
|
||||||
new_module = replace_linear_class(
|
|
||||||
child_module, style, self.quant_config, prefix=qual_name
|
|
||||||
)
|
|
||||||
elif child_module.__class__.__name__.endswith("RMSNorm"):
|
|
||||||
new_module = replace_rms_norm_class(
|
|
||||||
child_module, self.text_config.hidden_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_recursive_replace(child_module, prefix=qual_name)
|
|
||||||
|
|
||||||
if new_module is not child_module:
|
|
||||||
setattr(module, child_name, new_module)
|
|
||||||
log_replacement(qual_name, child_module, new_module)
|
|
||||||
|
|
||||||
_recursive_replace(self.model, prefix="model")
|
|
||||||
|
|
||||||
def create_attention_instances(
|
|
||||||
self, attn_type: AttentionType = AttentionType.DECODER
|
|
||||||
) -> dict[int, Attention]:
|
|
||||||
"""
|
|
||||||
Create `Attention` instances to inform KV cache allocation.
|
|
||||||
"""
|
|
||||||
num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
|
|
||||||
head_size = self.model_config.get_head_size()
|
|
||||||
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
|
||||||
logits_soft_cap = getattr(self.text_config, "attn_logit_softcapping", None)
|
|
||||||
start, end = get_pp_indices(
|
|
||||||
self.text_config.num_hidden_layers,
|
|
||||||
self.pp_group.rank_in_group,
|
|
||||||
self.pp_group.world_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_instances = {}
|
|
||||||
for i in range(start, end):
|
|
||||||
# Handle interleaved sliding window attention
|
|
||||||
per_layer_sliding_window = None
|
|
||||||
if (
|
|
||||||
hasattr(self.config, "layer_types")
|
|
||||||
and self.config.layer_types[i] == "sliding_attention"
|
|
||||||
):
|
|
||||||
per_layer_sliding_window = self.config.sliding_window
|
|
||||||
|
|
||||||
attention_instances[i] = Attention(
|
|
||||||
num_heads=num_heads,
|
|
||||||
head_size=head_size,
|
|
||||||
# NOTE: We use Llama scale as default, if it's set by
|
|
||||||
# Transformers, it's updated in vllm_flash_attention_forward
|
|
||||||
scale=head_size**-0.5,
|
|
||||||
num_kv_heads=num_kv_heads,
|
|
||||||
cache_config=self.cache_config,
|
|
||||||
quant_config=self.quant_config,
|
|
||||||
logits_soft_cap=logits_soft_cap,
|
|
||||||
per_layer_sliding_window=per_layer_sliding_window,
|
|
||||||
prefix=f"{i}.attn",
|
|
||||||
attn_type=attn_type,
|
|
||||||
)
|
|
||||||
return attention_instances
|
|
||||||
|
|
||||||
def init_parameters(self, module: nn.Module, dtype: torch.dtype | None = None):
|
|
||||||
"""
|
|
||||||
If a `parameter` is on the `meta` device, then its parent
|
|
||||||
`module` is the original module created by:
|
|
||||||
|
|
||||||
```python
|
|
||||||
with torch.device("meta"):
|
|
||||||
self.model: PreTrainedModel = AutoModel.from_config(...)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _init_parameters(module: nn.Module, dtype: torch.dtype | None):
|
|
||||||
for name, param in module.named_parameters(recurse=False):
|
|
||||||
if param.device == torch.device("meta"):
|
|
||||||
new_param = nn.Parameter(
|
|
||||||
torch.empty_like(
|
|
||||||
param.data,
|
|
||||||
dtype=dtype or self.model_config.dtype,
|
|
||||||
device=self.device_config.device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
setattr(module, name, new_param)
|
|
||||||
for child in module.children():
|
|
||||||
_init_parameters(child, dtype)
|
|
||||||
|
|
||||||
_init_parameters(module, dtype)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor | None,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
intermediate_tensors: IntermediateTensors | None = None,
|
|
||||||
inputs_embeds: torch.Tensor | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor | IntermediateTensors:
|
|
||||||
if not self.pp_group.is_first_rank:
|
|
||||||
assert intermediate_tensors is not None
|
|
||||||
input_ids = None
|
|
||||||
inputs_embeds = intermediate_tensors["hidden_states"]
|
|
||||||
|
|
||||||
if input_ids is not None:
|
|
||||||
input_ids = input_ids[None, ...]
|
|
||||||
if inputs_embeds is not None:
|
|
||||||
inputs_embeds = inputs_embeds[None, ...]
|
|
||||||
|
|
||||||
if self.model_config.uses_mrope:
|
|
||||||
position_ids = positions[:, None]
|
|
||||||
else:
|
|
||||||
position_ids = positions[None, ...]
|
|
||||||
|
|
||||||
hidden_states = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=False,
|
|
||||||
position_ids=position_ids,
|
|
||||||
attention_instances=self.attention_instances,
|
|
||||||
return_dict=False,
|
|
||||||
**kwargs,
|
|
||||||
)[0][0, ...] # we remove batch dimension for now
|
|
||||||
|
|
||||||
if not self.pp_group.is_last_rank:
|
|
||||||
return IntermediateTensors({"hidden_states": hidden_states})
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def load_weights(
|
|
||||||
self,
|
|
||||||
weights: Iterable[tuple[str, torch.Tensor]],
|
|
||||||
) -> set[str]:
|
|
||||||
loader = AutoWeightsLoader(
|
|
||||||
self,
|
|
||||||
skip_prefixes=self.skip_prefixes,
|
|
||||||
skip_substrs=self.skip_substrs,
|
|
||||||
ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
|
|
||||||
ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
|
|
||||||
)
|
|
||||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
|
||||||
|
|
||||||
def check_version(self, min_version: str, feature: str):
|
|
||||||
installed = Version(transformers.__version__)
|
|
||||||
required = Version(min_version)
|
|
||||||
if installed < required:
|
|
||||||
raise ImportError(
|
|
||||||
f"Transformers backend requires transformers>={required} "
|
|
||||||
f"for {feature}, but got {installed}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
|
||||||
class TransformersForCausalLM(TransformersBase):
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
|
||||||
|
|
||||||
# Tell `TransformersBase.load_weights` to skip
|
|
||||||
# `lm_head` if the model has tied word embeddings
|
|
||||||
if self.text_config.tie_word_embeddings:
|
|
||||||
self.skip_prefixes.append("lm_head.")
|
|
||||||
|
|
||||||
if self.pp_group.is_last_rank:
|
|
||||||
self.unpadded_vocab_size = self.text_config.vocab_size
|
|
||||||
self.lm_head = ParallelLMHead(
|
|
||||||
self.text_config.vocab_size,
|
|
||||||
self.text_config.hidden_size,
|
|
||||||
quant_config=self.quant_config,
|
|
||||||
prefix=maybe_prefix(prefix, "lm_head"),
|
|
||||||
)
|
|
||||||
if self.text_config.tie_word_embeddings:
|
|
||||||
self.lm_head = self.lm_head.tie_weights(
|
|
||||||
self.model.get_input_embeddings()
|
|
||||||
)
|
|
||||||
|
|
||||||
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
|
|
||||||
self.logits_processor = LogitsProcessor(
|
|
||||||
self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.lm_head = PPMissingLayer()
|
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.model.get_input_embeddings()(input_ids)
|
|
||||||
if self.embed_scale is not None:
|
|
||||||
inputs_embeds *= self.embed_scale
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def compute_logits(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
) -> torch.Tensor | None:
|
|
||||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
|
||||||
MultiModalProcessor,
|
|
||||||
info=MultiModalProcessingInfo,
|
|
||||||
dummy_inputs=MultiModalDummyInputsBuilder,
|
|
||||||
)
|
|
||||||
@support_torch_compile(
|
|
||||||
# set `positions` to last dim to support Qwen-mrope
|
|
||||||
dynamic_arg_dims={
|
|
||||||
"input_ids": 0,
|
|
||||||
"positions": -1,
|
|
||||||
"intermediate_tensors": 0,
|
|
||||||
"inputs_embeds": 0,
|
|
||||||
},
|
|
||||||
enable_if=can_enable_torch_compile,
|
|
||||||
)
|
|
||||||
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
|
|
||||||
supports_multimodal_raw_input_only = True
|
|
||||||
merge_by_field_config = True
|
|
||||||
# Backwards compatibility for prev released models. State dicts back then
|
|
||||||
# had different formats and cannot be loaded with `AutoModel` mapping as is
|
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
|
||||||
orig_to_new_prefix={
|
|
||||||
"language_model.model": "model.language_model",
|
|
||||||
"text_model.model": "model.text_model",
|
|
||||||
"vision_tower": "model.vision_tower",
|
|
||||||
"vqmodel": "model.vqmodel",
|
|
||||||
"visual": "model.visual",
|
|
||||||
"vision_model": "model.vision_model",
|
|
||||||
"vision_embed_tokens": "model.vision_embed_tokens",
|
|
||||||
"image_newline": "model.image_newline",
|
|
||||||
"multi_modal_projector": "model.multi_modal_projector",
|
|
||||||
"text_model.lm_head": "lm_head",
|
|
||||||
"language_model.lm_head": "lm_head",
|
|
||||||
# Qwen models used "model" as the name for the language model.
|
|
||||||
# Therefore, we must map each of submodule explicitly to avoid
|
|
||||||
# conflicts with newer models that use "model.language_model".
|
|
||||||
"model.embed_tokens": "model.language_model.embed_tokens",
|
|
||||||
"model.layers": "model.language_model.layers",
|
|
||||||
"model.norm": "model.language_model.norm",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
|
||||||
|
|
||||||
self.dtype = vllm_config.model_config.dtype
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor | None,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
intermediate_tensors: IntermediateTensors | None = None,
|
|
||||||
inputs_embeds: torch.Tensor | None = None,
|
|
||||||
**kwargs: object,
|
|
||||||
) -> torch.Tensor | IntermediateTensors:
|
|
||||||
# Gemma3 and PaliGemma needs `token_type_ids` to work correctly
|
|
||||||
# Other models will not have `token_type_ids` in kwargs
|
|
||||||
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
|
|
||||||
model_output = super().forward(
|
|
||||||
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
|
|
||||||
)
|
|
||||||
return model_output
|
|
||||||
|
|
||||||
def get_language_model(self) -> torch.nn.Module:
|
|
||||||
"""`TransformersForMultimodalLM` does not contain a vLLM language model class.
|
|
||||||
Therefore, in order to return a language model vLLM class, we use a wrapper to
|
|
||||||
give `self` the same interface as `TransformersForCausalLM`."""
|
|
||||||
|
|
||||||
class LanguageModelWrapper(TransformersForCausalLM):
|
|
||||||
def __init__(self, multimodal_model):
|
|
||||||
# Don't call super().__init__() to avoid re-initialization
|
|
||||||
self.__dict__.update(multimodal_model.__dict__)
|
|
||||||
|
|
||||||
model = getattr_iter(self.model, ("language_model", "text_model"), None)
|
|
||||||
|
|
||||||
return LanguageModelWrapper(self)
|
|
||||||
|
|
||||||
def get_multimodal_embeddings(self, **kwargs):
|
|
||||||
pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
|
|
||||||
image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
|
|
||||||
# Model might use `image_patches` instead of `pixel_values`
|
|
||||||
if pixel_values is None:
|
|
||||||
pixel_values = kwargs.pop("image_patches", None)
|
|
||||||
|
|
||||||
if image_embeds is not None:
|
|
||||||
return image_embeds
|
|
||||||
|
|
||||||
if pixel_values is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
num_image_patches = kwargs.pop("num_image_patches")
|
|
||||||
kwargs.pop("token_type_ids", None) # used only in `forward`
|
|
||||||
if pixel_values is not None:
|
|
||||||
vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
|
|
||||||
|
|
||||||
if isinstance(vision_embeddings, torch.Tensor):
|
|
||||||
if vision_embeddings.ndim == 2:
|
|
||||||
vision_embeddings = vision_embeddings.unsqueeze(0)
|
|
||||||
|
|
||||||
# Embeddings have to be 2D tensors of length `num_images`
|
|
||||||
# but transformers returns concat tensors if each patch
|
|
||||||
# is of different size. We split it back to make vLLM happy
|
|
||||||
vision_embeddings = torch.split(
|
|
||||||
vision_embeddings, num_image_patches.flatten().tolist()
|
|
||||||
)
|
|
||||||
vision_embeddings = [
|
|
||||||
embed.flatten(start_dim=0, end_dim=-2)
|
|
||||||
for embed in vision_embeddings
|
|
||||||
]
|
|
||||||
|
|
||||||
return vision_embeddings
|
|
||||||
|
|
||||||
get_input_embeddings = SupportsMultiModal.get_input_embeddings
|
|
||||||
127
vllm/model_executor/models/transformers/__init__.py
Normal file
127
vllm/model_executor/models/transformers/__init__.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Copyright 2024 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Wrapper around `transformers` models"""
|
||||||
|
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
|
from vllm.model_executor.models.transformers.base import Base
|
||||||
|
from vllm.model_executor.models.transformers.causal import CausalMixin
|
||||||
|
from vllm.model_executor.models.transformers.legacy import LegacyMixin
|
||||||
|
from vllm.model_executor.models.transformers.moe import MoEMixin
|
||||||
|
from vllm.model_executor.models.transformers.multimodal import (
|
||||||
|
DYNAMIC_ARG_DIMS,
|
||||||
|
MultiModalDummyInputsBuilder,
|
||||||
|
MultiModalMixin,
|
||||||
|
MultiModalProcessingInfo,
|
||||||
|
MultiModalProcessor,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.transformers.pooling import (
|
||||||
|
EmbeddingMixin,
|
||||||
|
SequenceClassificationMixin,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.transformers.utils import can_enable_torch_compile
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
|
||||||
|
|
||||||
|
# Text only models
|
||||||
|
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||||
|
class TransformersForCausalLM(CausalMixin, Base): ...
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||||
|
class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
|
||||||
|
|
||||||
|
|
||||||
|
# Multimodal models
|
||||||
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
MultiModalProcessor,
|
||||||
|
info=MultiModalProcessingInfo,
|
||||||
|
dummy_inputs=MultiModalDummyInputsBuilder,
|
||||||
|
)
|
||||||
|
@support_torch_compile(
|
||||||
|
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
|
||||||
|
)
|
||||||
|
class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
MultiModalProcessor,
|
||||||
|
info=MultiModalProcessingInfo,
|
||||||
|
dummy_inputs=MultiModalDummyInputsBuilder,
|
||||||
|
)
|
||||||
|
@support_torch_compile(
|
||||||
|
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
|
||||||
|
)
|
||||||
|
class TransformersMultiModalMoEForCausalLM(
|
||||||
|
MoEMixin, MultiModalMixin, CausalMixin, Base
|
||||||
|
): ...
|
||||||
|
|
||||||
|
|
||||||
|
# Embedding models
|
||||||
|
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||||
|
class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ...
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||||
|
class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
MultiModalProcessor,
|
||||||
|
info=MultiModalProcessingInfo,
|
||||||
|
dummy_inputs=MultiModalDummyInputsBuilder,
|
||||||
|
)
|
||||||
|
@support_torch_compile(
|
||||||
|
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
|
||||||
|
)
|
||||||
|
class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ...
|
||||||
|
|
||||||
|
|
||||||
|
# Sequence classification models
|
||||||
|
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||||
|
class TransformersForSequenceClassification(
|
||||||
|
SequenceClassificationMixin, LegacyMixin, Base
|
||||||
|
): ...
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||||
|
class TransformersMoEForSequenceClassification(
|
||||||
|
SequenceClassificationMixin, MoEMixin, Base
|
||||||
|
): ...
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
MultiModalProcessor,
|
||||||
|
info=MultiModalProcessingInfo,
|
||||||
|
dummy_inputs=MultiModalDummyInputsBuilder,
|
||||||
|
)
|
||||||
|
@support_torch_compile(
|
||||||
|
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
|
||||||
|
)
|
||||||
|
class TransformersMultiModalForSequenceClassification(
|
||||||
|
SequenceClassificationMixin, MultiModalMixin, Base
|
||||||
|
): ...
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
|
||||||
|
"""Handle imports of non-existent classes with a helpful error message."""
|
||||||
|
if name not in globals():
|
||||||
|
raise AttributeError(
|
||||||
|
"The Transformers backend does not currently have a class to handle "
|
||||||
|
f"the requested model type: {name}. Please open an issue at "
|
||||||
|
"https://github.com/vllm-project/vllm/issues/new"
|
||||||
|
)
|
||||||
|
return globals()[name]
|
||||||
435
vllm/model_executor/models/transformers/base.py
Normal file
435
vllm/model_executor/models/transformers/base.py
Normal file
@@ -0,0 +1,435 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Copyright 2024 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Transformers backend base class."""
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from packaging.version import Version
|
||||||
|
from torch import nn
|
||||||
|
from transformers import AutoModel
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
|
from vllm.attention import Attention, AttentionType
|
||||||
|
from vllm.config.utils import getattr_iter
|
||||||
|
from vllm.distributed import get_pp_group, get_tp_group
|
||||||
|
from vllm.distributed.utils import get_pp_indices
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
|
from vllm.model_executor.models.interfaces import (
|
||||||
|
SupportsLoRA,
|
||||||
|
SupportsPP,
|
||||||
|
SupportsQuant,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.interfaces_base import VllmModel
|
||||||
|
from vllm.model_executor.models.transformers.utils import (
|
||||||
|
get_feature_request_tip,
|
||||||
|
init_on_device_without_buffers,
|
||||||
|
log_replacement,
|
||||||
|
replace_linear_class,
|
||||||
|
replace_rms_norm_class,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.utils import (
|
||||||
|
AutoWeightsLoader,
|
||||||
|
PPMissingLayer,
|
||||||
|
make_empty_intermediate_tensors_factory,
|
||||||
|
maybe_prefix,
|
||||||
|
)
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
else:
|
||||||
|
PreTrainedModel = object
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def vllm_flash_attention_forward(
|
||||||
|
# Transformers args
|
||||||
|
module: torch.nn.Module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
# Transformers kwargs
|
||||||
|
scaling: float | None = None,
|
||||||
|
# vLLM kwargs
|
||||||
|
attention_instances: dict[int, Attention] | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self_attn = attention_instances[module.layer_idx]
|
||||||
|
if scaling is not None:
|
||||||
|
self_attn.impl.scale = float(scaling)
|
||||||
|
hidden = query.shape[-2]
|
||||||
|
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||||
|
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
|
||||||
|
return self_attn.forward(query, key, value), None
|
||||||
|
|
||||||
|
|
||||||
|
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
|
||||||
|
|
||||||
|
|
||||||
|
class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||||
|
embedding_padding_modules = ["lm_head"]
|
||||||
|
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
logger.info("Using Transformers backend.")
|
||||||
|
|
||||||
|
self.config = vllm_config.model_config.hf_config
|
||||||
|
self.text_config = self.config.get_text_config()
|
||||||
|
self.cache_config = vllm_config.cache_config
|
||||||
|
self.device_config = vllm_config.device_config
|
||||||
|
self.model_config = vllm_config.model_config
|
||||||
|
self.parallel_config = vllm_config.parallel_config
|
||||||
|
self.quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
|
self.tp_group = get_tp_group()
|
||||||
|
|
||||||
|
# Weights to skip in `self.load_weights`
|
||||||
|
self.skip_prefixes: list[str] = []
|
||||||
|
"""Skip loading weights whose qualname starts with these prefixes."""
|
||||||
|
self.skip_substrs: list[str] = []
|
||||||
|
"""Skip loading weights whose qualname contains these substrings."""
|
||||||
|
self.ignore_unexpected_prefixes: list[str] = []
|
||||||
|
"""Ignore unexpected weights whose qualname starts with these prefixes.
|
||||||
|
"""
|
||||||
|
self.ignore_unexpected_suffixes: list[str] = []
|
||||||
|
"""Ignore unexpected weights whose qualname ends with these suffixes."""
|
||||||
|
|
||||||
|
if self.quant_config:
|
||||||
|
quant_method_name = self.quant_config.get_name()
|
||||||
|
# Check for unsupported quantization methods.
|
||||||
|
if quant_method_name == "mxfp4":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Transformers backend does not support MXFP4 quantization yet."
|
||||||
|
)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if "gptq" in quant_method_name:
|
||||||
|
self.ignore_unexpected_suffixes.append(".bias")
|
||||||
|
|
||||||
|
# Set correct attn and init on "meta" to delay allocating GPU tensors
|
||||||
|
self.text_config._attn_implementation = "vllm"
|
||||||
|
with init_on_device_without_buffers("meta"):
|
||||||
|
self.model: PreTrainedModel = AutoModel.from_config(
|
||||||
|
self.config,
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove layers not on this pipeline parallel rank
|
||||||
|
self.pipeline_parallel()
|
||||||
|
# Substitute remaining layers with vLLM's layers as needed
|
||||||
|
self.recursive_replace()
|
||||||
|
# Create attention instances for KV cache allocation
|
||||||
|
self.attention_instances = self.create_attention_instances()
|
||||||
|
|
||||||
|
# Input embeddings
|
||||||
|
input_embeddings = self.model.get_input_embeddings()
|
||||||
|
if not isinstance(input_embeddings, PPMissingLayer):
|
||||||
|
# Some models scale embeddings inside the input embedding layer
|
||||||
|
self.embed_scale = getattr(input_embeddings, "embed_scale", None)
|
||||||
|
names = ("embedding_size", "hidden_size")
|
||||||
|
embedding_dim = getattr_iter(self.text_config, names, None)
|
||||||
|
assert embedding_dim is not None
|
||||||
|
self.model.set_input_embeddings(
|
||||||
|
VocabParallelEmbedding(
|
||||||
|
self.text_config.vocab_size,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
org_num_embeddings=self.text_config.vocab_size,
|
||||||
|
quant_config=self.quant_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize any parameters that have not had their modules replaced
|
||||||
|
self.init_parameters(self.model)
|
||||||
|
|
||||||
|
# Pipeline parallel intermediate tensors
|
||||||
|
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||||
|
["hidden_states"], self.text_config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def pipeline_parallel(self):
|
||||||
|
"""
|
||||||
|
Apply the model's pipeline parallelization plan.
|
||||||
|
"""
|
||||||
|
if self.pp_group.world_size <= 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.model.supports_pp_plan:
|
||||||
|
tip = get_feature_request_tip(
|
||||||
|
self.model_config.model, self.model_config.trust_remote_code
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"{type(self.model)} does not support pipeline parallel. {tip}"
|
||||||
|
)
|
||||||
|
|
||||||
|
module_lists = []
|
||||||
|
module_list_idx = None
|
||||||
|
pp_plan = list(self.model._pp_plan.keys())
|
||||||
|
for i, name in enumerate(pp_plan):
|
||||||
|
if isinstance(getattr(self.model, name), nn.ModuleList):
|
||||||
|
module_lists.append(name)
|
||||||
|
module_list_idx = i
|
||||||
|
|
||||||
|
if len(module_lists) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Pipeline parallel of models with multiple `ModuleList`s "
|
||||||
|
"in the base model are not supported yet!"
|
||||||
|
)
|
||||||
|
if module_list_idx is None:
|
||||||
|
raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")
|
||||||
|
|
||||||
|
# Layers before module list
|
||||||
|
for name in pp_plan[:module_list_idx]:
|
||||||
|
if self.pp_group.is_first_rank or (
|
||||||
|
self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
setattr(self.model, name, PPMissingLayer())
|
||||||
|
|
||||||
|
# Module list
|
||||||
|
start_layer, end_layer = get_pp_indices(
|
||||||
|
self.text_config.num_hidden_layers,
|
||||||
|
self.pp_group.rank_in_group,
|
||||||
|
self.pp_group.world_size,
|
||||||
|
)
|
||||||
|
layers_name = pp_plan[module_list_idx]
|
||||||
|
layers = getattr(self.model, layers_name)
|
||||||
|
for i in range(len(layers)):
|
||||||
|
if start_layer <= i and i < end_layer:
|
||||||
|
continue
|
||||||
|
layers[i] = PPMissingLayer()
|
||||||
|
|
||||||
|
# Layers after module list
|
||||||
|
for name in pp_plan[module_list_idx + 1 :]:
|
||||||
|
# Modules that should be on last rank
|
||||||
|
if not self.pp_group.is_last_rank:
|
||||||
|
setattr(self.model, name, PPMissingLayer())
|
||||||
|
|
||||||
|
def recursive_replace(self):
|
||||||
|
"""Recursively replace modules in the model as needed.
|
||||||
|
|
||||||
|
Currently, this replaces:
|
||||||
|
|
||||||
|
- `nn.Linear` with vLLM's tensor parallel linear classes
|
||||||
|
- `*RMSNorm` with vLLM's `RMSNorm`
|
||||||
|
"""
|
||||||
|
tp_plan = self.model.tp_plan
|
||||||
|
|
||||||
|
if not tp_plan and self.tp_group.world_size > 1:
|
||||||
|
tip = get_feature_request_tip(
|
||||||
|
self.model_config.model, self.model_config.trust_remote_code
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"{type(self.model)} does not support tensor parallel. {tip}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefix the patterns because we always start from `self.model`
|
||||||
|
tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}
|
||||||
|
|
||||||
|
def _recursive_replace(module: nn.Module, prefix: str):
|
||||||
|
for child_name, child_module in module.named_children():
|
||||||
|
new_module = child_module
|
||||||
|
qual_name = maybe_prefix(prefix, child_name)
|
||||||
|
if isinstance(child_module, nn.Linear):
|
||||||
|
generator = (p for p in tp_plan if re.match(p, qual_name))
|
||||||
|
pattern = next(generator, None)
|
||||||
|
# Some weight loaders expect all linear layers to inherit
|
||||||
|
# LinearBase, so we set a default style which causes any
|
||||||
|
# unspecified layers to be replaced with ReplicatedLinear
|
||||||
|
style = tp_plan.get(pattern, "replicate")
|
||||||
|
new_module = replace_linear_class(
|
||||||
|
child_module, style, self.quant_config, prefix=qual_name
|
||||||
|
)
|
||||||
|
elif child_module.__class__.__name__.endswith("RMSNorm"):
|
||||||
|
new_module = replace_rms_norm_class(
|
||||||
|
child_module, self.text_config.hidden_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_recursive_replace(child_module, prefix=qual_name)
|
||||||
|
|
||||||
|
if new_module is not child_module:
|
||||||
|
setattr(module, child_name, new_module)
|
||||||
|
log_replacement(qual_name, child_module, new_module)
|
||||||
|
|
||||||
|
_recursive_replace(self.model, prefix="model")
|
||||||
|
|
||||||
|
def create_attention_instances(self) -> dict[int, Attention]:
|
||||||
|
"""
|
||||||
|
Create `Attention` instances to inform KV cache allocation.
|
||||||
|
"""
|
||||||
|
text_config = self.text_config
|
||||||
|
|
||||||
|
num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
|
||||||
|
head_size = self.model_config.get_head_size()
|
||||||
|
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||||
|
logits_soft_cap = getattr(text_config, "attn_logit_softcapping", None)
|
||||||
|
|
||||||
|
# In encoder models, the attention layers will have `is_causal=False`
|
||||||
|
is_encoder = lambda module: not getattr(module, "is_causal", True)
|
||||||
|
has_encoder = lambda model: any(is_encoder(m) for m in model.modules())
|
||||||
|
is_multimodal = lambda config: config != config.get_text_config()
|
||||||
|
# vLLM does not support encoder-decoder models, so if any encoder layer is
|
||||||
|
# found in a text only model, we assume the whole model is an encoder model
|
||||||
|
if has_encoder(self.model) and not is_multimodal(self.config):
|
||||||
|
self.check_version("4.57.0.dev0", "encoder models support")
|
||||||
|
attn_type = AttentionType.ENCODER_ONLY
|
||||||
|
else:
|
||||||
|
attn_type = AttentionType.DECODER
|
||||||
|
|
||||||
|
pp_rank = self.pp_group.rank_in_group
|
||||||
|
pp_size = self.pp_group.world_size
|
||||||
|
start, end = get_pp_indices(text_config.num_hidden_layers, pp_rank, pp_size)
|
||||||
|
|
||||||
|
attention_instances = {}
|
||||||
|
for i in range(start, end):
|
||||||
|
# Handle interleaved sliding window attention
|
||||||
|
per_layer_sliding_window = None
|
||||||
|
if (
|
||||||
|
hasattr(self.config, "layer_types")
|
||||||
|
and self.config.layer_types[i] == "sliding_attention"
|
||||||
|
):
|
||||||
|
per_layer_sliding_window = self.config.sliding_window
|
||||||
|
|
||||||
|
attention_instances[i] = Attention(
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_size=head_size,
|
||||||
|
# NOTE: We use Llama scale as default, if it's set by
|
||||||
|
# Transformers, it's updated in vllm_flash_attention_forward
|
||||||
|
scale=head_size**-0.5,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
cache_config=self.cache_config,
|
||||||
|
quant_config=self.quant_config,
|
||||||
|
logits_soft_cap=logits_soft_cap,
|
||||||
|
per_layer_sliding_window=per_layer_sliding_window,
|
||||||
|
prefix=f"{i}.attn",
|
||||||
|
attn_type=attn_type,
|
||||||
|
)
|
||||||
|
return attention_instances
|
||||||
|
|
||||||
|
def init_parameters(self, module: nn.Module, dtype: torch.dtype | None = None):
|
||||||
|
"""
|
||||||
|
If a `parameter` is on the `meta` device, then its parent
|
||||||
|
`module` is the original module created by:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with torch.device("meta"):
|
||||||
|
self.model: "PreTrainedModel" = AutoModel.from_config(...)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _init_parameters(module: nn.Module, dtype: torch.dtype | None):
|
||||||
|
for name, param in module.named_parameters(recurse=False):
|
||||||
|
if param.device == torch.device("meta"):
|
||||||
|
new_param = nn.Parameter(
|
||||||
|
torch.empty_like(
|
||||||
|
param.data,
|
||||||
|
dtype=dtype or self.model_config.dtype,
|
||||||
|
device=self.device_config.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
setattr(module, name, new_param)
|
||||||
|
for child in module.children():
|
||||||
|
_init_parameters(child, dtype)
|
||||||
|
|
||||||
|
_init_parameters(module, dtype)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
inputs_embeds = self.model.get_input_embeddings()(input_ids)
|
||||||
|
if self.embed_scale is not None:
|
||||||
|
inputs_embeds *= self.embed_scale
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor | None,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: IntermediateTensors | None = None,
|
||||||
|
inputs_embeds: torch.Tensor | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor | IntermediateTensors:
|
||||||
|
if not self.pp_group.is_first_rank:
|
||||||
|
assert intermediate_tensors is not None
|
||||||
|
input_ids = None
|
||||||
|
inputs_embeds = intermediate_tensors["hidden_states"]
|
||||||
|
|
||||||
|
if input_ids is not None:
|
||||||
|
input_ids = input_ids[None, ...]
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
inputs_embeds = inputs_embeds[None, ...]
|
||||||
|
|
||||||
|
# If the model scales embeddings inside the input embedding layer we must
|
||||||
|
# ensure they are scaled here since VocabParallelEmbedding will not do it
|
||||||
|
if (
|
||||||
|
self.embed_scale is not None
|
||||||
|
and input_ids is not None
|
||||||
|
and inputs_embeds is None
|
||||||
|
):
|
||||||
|
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||||
|
input_ids = None
|
||||||
|
|
||||||
|
if self.model_config.uses_mrope:
|
||||||
|
position_ids = positions[:, None]
|
||||||
|
else:
|
||||||
|
position_ids = positions[None, ...]
|
||||||
|
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=False,
|
||||||
|
position_ids=position_ids,
|
||||||
|
attention_instances=self.attention_instances,
|
||||||
|
return_dict=False,
|
||||||
|
**kwargs,
|
||||||
|
)[0][0, ...] # we remove batch dimension for now
|
||||||
|
|
||||||
|
if not self.pp_group.is_last_rank:
|
||||||
|
return IntermediateTensors({"hidden_states": hidden_states})
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def load_weights(
|
||||||
|
self,
|
||||||
|
weights: Iterable[tuple[str, torch.Tensor]],
|
||||||
|
) -> set[str]:
|
||||||
|
loader = AutoWeightsLoader(
|
||||||
|
self,
|
||||||
|
skip_prefixes=self.skip_prefixes,
|
||||||
|
skip_substrs=self.skip_substrs,
|
||||||
|
ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
|
||||||
|
ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
|
||||||
|
)
|
||||||
|
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_version(min_version: str, feature: str):
|
||||||
|
installed = Version(transformers.__version__)
|
||||||
|
required = Version(min_version)
|
||||||
|
if installed < required:
|
||||||
|
raise ImportError(
|
||||||
|
f"Transformers backend requires transformers>={required} "
|
||||||
|
f"for {feature}, but got {installed}"
|
||||||
|
)
|
||||||
66
vllm/model_executor/models/transformers/causal.py
Normal file
66
vllm/model_executor/models/transformers/causal.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Copyright 2024 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Transformers backend mixin for causal language models."""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
from vllm.model_executor.models.interfaces_base import VllmModelForTextGeneration
|
||||||
|
from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
|
||||||
|
class CausalMixin(VllmModelForTextGeneration):
|
||||||
|
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||||
|
# Skip VllmModelForTextGeneration.__init__ and call the next class in MRO
|
||||||
|
super(VllmModelForTextGeneration, self).__init__(
|
||||||
|
vllm_config=vllm_config, prefix=prefix
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tell `Base.load_weights` to skip
|
||||||
|
# `lm_head` if the model has tied word embeddings
|
||||||
|
if self.text_config.tie_word_embeddings:
|
||||||
|
self.skip_prefixes.append("lm_head.")
|
||||||
|
|
||||||
|
if self.pp_group.is_last_rank:
|
||||||
|
self.unpadded_vocab_size = self.text_config.vocab_size
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
self.text_config.vocab_size,
|
||||||
|
self.text_config.hidden_size,
|
||||||
|
quant_config=self.quant_config,
|
||||||
|
prefix=maybe_prefix(prefix, "lm_head"),
|
||||||
|
)
|
||||||
|
if self.text_config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.lm_head.tie_weights(
|
||||||
|
self.model.get_input_embeddings()
|
||||||
|
)
|
||||||
|
|
||||||
|
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
|
||||||
|
self.logits_processor = LogitsProcessor(
|
||||||
|
self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.lm_head = PPMissingLayer()
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: "torch.Tensor") -> "torch.Tensor | None":
|
||||||
|
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||||
|
return logits
|
||||||
97
vllm/model_executor/models/transformers/legacy.py
Normal file
97
vllm/model_executor/models/transformers/legacy.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Copyright 2024 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Transformers backend mixin for legacy models."""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.models.utils import WeightsMapper
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyMixin:
|
||||||
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
|
# These are applied in order, so the order matters!
|
||||||
|
orig_to_new_prefix={
|
||||||
|
# Handle BERT-like models
|
||||||
|
"roberta": "model",
|
||||||
|
"bert": "model",
|
||||||
|
# Add `model.` prefix for base model checkpoints
|
||||||
|
"": "model.",
|
||||||
|
# Remove `model.` prefix if it was already there
|
||||||
|
"model.model.": "model.",
|
||||||
|
# Classifier/scoring heads will be adjacent to `model`
|
||||||
|
"model.score": "classifier",
|
||||||
|
"model.classifier": "classifier",
|
||||||
|
},
|
||||||
|
orig_to_new_suffix={
|
||||||
|
# Replace legacy suffixes used for norms
|
||||||
|
".gamma": ".weight",
|
||||||
|
".beta": ".bias",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||||
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
|
# Skip unsupported/unwanted output embeddings layers
|
||||||
|
self.skip_prefixes.extend(
|
||||||
|
[
|
||||||
|
"model.lm_head.",
|
||||||
|
"model.predictions.",
|
||||||
|
"model.qa_outputs.",
|
||||||
|
"model.embeddings_project.",
|
||||||
|
"model.discriminator_predictions.",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Some encoder models have the position_ids buffer in the checkpoint.
|
||||||
|
# vLLM will always pass position_ids as an argument, so we skip loading
|
||||||
|
# the buffer if it exists
|
||||||
|
self.skip_substrs.append("position_ids")
|
||||||
|
|
||||||
|
# Some encoder models have the bias of the final classifier layer
|
||||||
|
# in the checkpoint. vLLM does not use this bias, so we skip loading
|
||||||
|
# it if it exists
|
||||||
|
self.skip_substrs.append("score.bias")
|
||||||
|
|
||||||
|
# roberta-like models an extra padding in positions.
|
||||||
|
# FIXME(Isotr0py): This is quite hacky for roberta edge case,
|
||||||
|
# we should find a better way to handle this.
|
||||||
|
self.is_roberta = "roberta" in self.text_config.model_type
|
||||||
|
self.padding_idx = self.text_config.pad_token_id
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor | None,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: IntermediateTensors | None = None,
|
||||||
|
inputs_embeds: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor | IntermediateTensors:
|
||||||
|
if self.is_roberta:
|
||||||
|
# RoBERTa-specific positions padding
|
||||||
|
positions += self.padding_idx + 1
|
||||||
|
return super().forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
@@ -14,31 +14,27 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Wrapper around `transformers` MoE models."""
|
"""Transformers backend mixin for Mixture of Experts (MoE) models."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
|
||||||
from vllm.config.utils import getattr_iter
|
from vllm.config.utils import getattr_iter
|
||||||
from vllm.distributed import get_dp_group, get_ep_group
|
from vllm.distributed import get_dp_group, get_ep_group
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
|
from vllm.model_executor.models.interfaces import MixtureOfExperts
|
||||||
|
from vllm.model_executor.models.utils import maybe_prefix
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
from .interfaces import MixtureOfExperts, SupportsMultiModal
|
from .utils import log_replacement
|
||||||
from .transformers import (
|
|
||||||
TransformersBase,
|
if TYPE_CHECKING:
|
||||||
TransformersForCausalLM,
|
from vllm.config import VllmConfig
|
||||||
TransformersForMultimodalLM,
|
|
||||||
can_enable_torch_compile,
|
|
||||||
log_replacement,
|
|
||||||
)
|
|
||||||
from .utils import maybe_prefix
|
|
||||||
|
|
||||||
|
|
||||||
@CustomOp.register("transformers_fused_moe")
|
@CustomOp.register("transformers_fused_moe")
|
||||||
@@ -117,11 +113,11 @@ direct_register_custom_op(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TransformersMoEBase(TransformersBase, MixtureOfExperts):
|
class MoEMixin(MixtureOfExperts):
|
||||||
def __init__(self, *, vllm_config, prefix=""):
|
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||||
self.check_version("4.57.0.dev0", "MoE models support")
|
self.check_version("4.57.0.dev0", "MoE models support")
|
||||||
self.ep_group = get_ep_group()
|
# Skip MixtureOfExperts.__init__ and call the next class in MRO
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
super(MixtureOfExperts, self).__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
def set_eplb_state(
|
def set_eplb_state(
|
||||||
self,
|
self,
|
||||||
@@ -242,7 +238,7 @@ class TransformersMoEBase(TransformersBase, MixtureOfExperts):
|
|||||||
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
|
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
|
||||||
|
|
||||||
# MixtureOfExperts mixin settings
|
# MixtureOfExperts mixin settings
|
||||||
ep_size = self.ep_group.world_size
|
ep_size = get_ep_group().world_size
|
||||||
|
|
||||||
self.mlp_layers = [] # Used for MixtureOfExperts methods
|
self.mlp_layers = [] # Used for MixtureOfExperts methods
|
||||||
self.expert_weights = []
|
self.expert_weights = []
|
||||||
@@ -316,24 +312,5 @@ class TransformersMoEBase(TransformersBase, MixtureOfExperts):
|
|||||||
_recursive_replace(child_module, prefix=qual_name)
|
_recursive_replace(child_module, prefix=qual_name)
|
||||||
|
|
||||||
_recursive_replace(self.model, prefix="model")
|
_recursive_replace(self.model, prefix="model")
|
||||||
# Continue with the replacement of layers in TransformersBase
|
# Continue with the replacement of layers in Base
|
||||||
super().recursive_replace()
|
super().recursive_replace()
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
|
||||||
class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile(
|
|
||||||
# set `positions` to last dim to support Qwen-mrope
|
|
||||||
dynamic_arg_dims={
|
|
||||||
"input_ids": 0,
|
|
||||||
"positions": -1,
|
|
||||||
"intermediate_tensors": 0,
|
|
||||||
"inputs_embeds": 0,
|
|
||||||
},
|
|
||||||
enable_if=can_enable_torch_compile,
|
|
||||||
)
|
|
||||||
class TransformersMoEForMultimodalLM(TransformersMoEBase, TransformersForMultimodalLM):
|
|
||||||
get_input_embeddings = SupportsMultiModal.get_input_embeddings
|
|
||||||
396
vllm/model_executor/models/transformers/multimodal.py
Normal file
396
vllm/model_executor/models/transformers/multimodal.py
Normal file
@@ -0,0 +1,396 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Copyright 2024 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Transformers backend mixin for multi-modal models."""
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config.utils import getattr_iter
|
||||||
|
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
|
||||||
|
from vllm.model_executor.models.utils import WeightsMapper
|
||||||
|
from vllm.multimodal import MultiModalKwargsItems
|
||||||
|
from vllm.multimodal.inputs import (
|
||||||
|
MultiModalDataDict,
|
||||||
|
MultiModalFieldConfig,
|
||||||
|
MultiModalInputs,
|
||||||
|
MultiModalUUIDDict,
|
||||||
|
PlaceholderRange,
|
||||||
|
)
|
||||||
|
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
|
||||||
|
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
|
||||||
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import BatchFeature, PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
|
|
||||||
|
DYNAMIC_ARG_DIMS = {
|
||||||
|
"input_ids": 0,
|
||||||
|
# set `positions` to last dim to support Qwen-mrope
|
||||||
|
"positions": -1,
|
||||||
|
"intermediate_tensors": 0,
|
||||||
|
"inputs_embeds": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalProcessingInfo(BaseProcessingInfo):
|
||||||
|
def get_supported_mm_limits(self):
|
||||||
|
return {"image": None}
|
||||||
|
|
||||||
|
def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
|
||||||
|
return {"image": self.get_max_image_tokens()}
|
||||||
|
|
||||||
|
def get_max_image_tokens(self) -> int:
|
||||||
|
width, height = self.get_max_image_size()
|
||||||
|
processor = self.get_hf_processor()
|
||||||
|
multimodal_config = self.ctx.model_config.multimodal_config
|
||||||
|
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
|
||||||
|
mm_tokens = processor._get_num_multimodal_tokens(
|
||||||
|
image_sizes=([height, width],), **mm_processor_kwargs
|
||||||
|
)
|
||||||
|
image_tokens = mm_tokens["num_image_tokens"][0]
|
||||||
|
return image_tokens
|
||||||
|
|
||||||
|
def get_max_image_size(self):
|
||||||
|
return 10_000, 10_000 # hardcode for arbitrary very large size
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
|
||||||
|
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||||
|
num_images = mm_counts.get("image", 0)
|
||||||
|
|
||||||
|
processor = self.info.get_hf_processor()
|
||||||
|
if "gemma3" in processor.__class__.__name__.lower():
|
||||||
|
image_token = processor.boi_token
|
||||||
|
else:
|
||||||
|
image_token = getattr(processor, "image_token", "")
|
||||||
|
return image_token * num_images
|
||||||
|
|
||||||
|
def get_dummy_mm_data(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
mm_options: Mapping[str, "BaseDummyOptions"] | None = None,
|
||||||
|
) -> MultiModalDataDict:
|
||||||
|
num_images = mm_counts.get("image", 0)
|
||||||
|
|
||||||
|
target_width, target_height = self.info.get_max_image_size()
|
||||||
|
|
||||||
|
image_overrides = mm_options.get("image") if mm_options else None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"image": self._get_dummy_images(
|
||||||
|
width=target_width,
|
||||||
|
height=target_height,
|
||||||
|
num_images=num_images,
|
||||||
|
overrides=image_overrides,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
|
||||||
|
def _get_prompt_updates(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
out_mm_kwargs: MultiModalKwargsItems,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Given the original multi-modal items for this modality
|
||||||
|
and HF-processed data, output the updates to perform.
|
||||||
|
|
||||||
|
The information returned by this method is used to update token inputs
|
||||||
|
which bypass the HF processor. It is also used to update the output of
|
||||||
|
HF processor if the HF process does not apply prompt updates to text
|
||||||
|
inputs.
|
||||||
|
|
||||||
|
Moreover, this information is critical to determine the token positions
|
||||||
|
in order to construct :class:`~vllm-multimodal.input.PlaceholderRange`
|
||||||
|
for each multi-modal item.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_mm_fields_config(
|
||||||
|
self,
|
||||||
|
hf_inputs: "BatchFeature",
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
|
# HF Processors always return a mask but vLLM doesn't need it
|
||||||
|
hf_inputs.pop("attention_mask", None)
|
||||||
|
num_image_patches = hf_inputs.get("num_image_patches")
|
||||||
|
mm_fields = {
|
||||||
|
key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
|
||||||
|
for key in hf_inputs
|
||||||
|
}
|
||||||
|
mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"image", num_image_patches
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keep these as batched, as they always have batch size as first dim
|
||||||
|
mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
|
||||||
|
mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
|
||||||
|
mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
|
||||||
|
return mm_fields
|
||||||
|
|
||||||
|
def _get_hf_mm_data(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
) -> tuple[Mapping[str, object], Mapping[str, object]]:
|
||||||
|
"""
|
||||||
|
In contrast to the base class, this method always adds
|
||||||
|
`return_mm_token_type_ids` to the processor data
|
||||||
|
"""
|
||||||
|
processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
|
||||||
|
processor_data["return_mm_token_type_ids"] = True
|
||||||
|
return processor_data, passthrough_data
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
prompt: str | list[int],
|
||||||
|
mm_data: MultiModalDataDict,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||||
|
mm_uuids: MultiModalUUIDDict | None = None,
|
||||||
|
) -> MultiModalInputs:
|
||||||
|
"""
|
||||||
|
Process multi-modal inputs to be used in vLLM.
|
||||||
|
|
||||||
|
Apply HF Processor on prompt text and multi-modal data together,
|
||||||
|
outputting token IDs and processed tensors.
|
||||||
|
"""
|
||||||
|
if tokenization_kwargs is None:
|
||||||
|
tokenization_kwargs = {}
|
||||||
|
|
||||||
|
mm_items = self._to_mm_items(mm_data)
|
||||||
|
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||||
|
if not isinstance(prompt, str):
|
||||||
|
# the prompt is the tokenized ids which is not supported
|
||||||
|
# by the hf_processor, which is why we would need to decode the ids
|
||||||
|
# into string
|
||||||
|
prompt = hf_processor.decode(prompt)
|
||||||
|
|
||||||
|
# Bypass cached processor and always apply to the full set of mm inputs
|
||||||
|
# NOTE: we can't just set caching=False because base class method
|
||||||
|
# transforms outputs to `MultiModalKwargs` which is not going to
|
||||||
|
# work for Transformers. We have a lot of logic tied to
|
||||||
|
# `mm_tokens_per_modality` below
|
||||||
|
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
|
||||||
|
prompt_text=prompt,
|
||||||
|
mm_items=mm_items,
|
||||||
|
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||||
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# For gemma3 we check `token_type_ids` as the key
|
||||||
|
token_type_key = (
|
||||||
|
"mm_token_type_ids"
|
||||||
|
if "mm_token_type_ids" in processed_data
|
||||||
|
else "token_type_ids"
|
||||||
|
)
|
||||||
|
mm_token_type_ids = processed_data.pop(token_type_key)
|
||||||
|
|
||||||
|
# We can infer vLLM style placeholder from token type ids, if we split
|
||||||
|
# it for each input `mm_data`.
|
||||||
|
mm_positions = torch.where(mm_token_type_ids == 1)[1]
|
||||||
|
images = mm_items.get_items("image", ImageProcessorItems)
|
||||||
|
multimodal_config = self.info.ctx.model_config.multimodal_config
|
||||||
|
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
|
||||||
|
image_sizes = []
|
||||||
|
for item_idx in range(len(images)):
|
||||||
|
image_size = images.get_image_size(item_idx)
|
||||||
|
image_sizes.append((image_size.height, image_size.width))
|
||||||
|
|
||||||
|
mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
|
||||||
|
image_sizes=image_sizes, **mm_processor_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
mm_placeholders = {}
|
||||||
|
split_sizes = mm_tokens_per_modality["num_image_tokens"]
|
||||||
|
if split_sizes:
|
||||||
|
chunked_mm_positions = torch.split(mm_positions, split_sizes)
|
||||||
|
mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
|
||||||
|
chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
|
||||||
|
ranges = [
|
||||||
|
PlaceholderRange(
|
||||||
|
offset=positions[0].item(),
|
||||||
|
length=positions.shape[0],
|
||||||
|
is_embed=(mm_tokens == hf_processor.image_token_id).bool(),
|
||||||
|
)
|
||||||
|
for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens)
|
||||||
|
]
|
||||||
|
mm_placeholders = {"image": ranges}
|
||||||
|
|
||||||
|
processed_data["num_image_patches"] = torch.tensor(
|
||||||
|
mm_tokens_per_modality["num_image_patches"]
|
||||||
|
)
|
||||||
|
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
|
||||||
|
processed_data,
|
||||||
|
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use overrides if provided; fallback to data-dependent hashing.
|
||||||
|
mm_hashes = self._hash_mm_items(
|
||||||
|
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
|
||||||
|
)
|
||||||
|
|
||||||
|
return MultiModalInputs(
|
||||||
|
type="multimodal",
|
||||||
|
prompt_token_ids=prompt_ids,
|
||||||
|
mm_kwargs=mm_kwargs,
|
||||||
|
mm_hashes=mm_hashes,
|
||||||
|
mm_placeholders=mm_placeholders,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||||
|
supports_multimodal_raw_input_only = True
|
||||||
|
merge_by_field_config = True
|
||||||
|
# Backwards compatibility for prev released models. State dicts back then
|
||||||
|
# had different formats and cannot be loaded with `AutoModel` mapping as is
|
||||||
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
|
orig_to_new_prefix={
|
||||||
|
"language_model.model": "model.language_model",
|
||||||
|
"text_model.model": "model.text_model",
|
||||||
|
"vision_tower": "model.vision_tower",
|
||||||
|
"vqmodel": "model.vqmodel",
|
||||||
|
"visual": "model.visual",
|
||||||
|
"vision_model": "model.vision_model",
|
||||||
|
"vision_embed_tokens": "model.vision_embed_tokens",
|
||||||
|
"image_newline": "model.image_newline",
|
||||||
|
"multi_modal_projector": "model.multi_modal_projector",
|
||||||
|
"text_model.lm_head": "lm_head",
|
||||||
|
"language_model.lm_head": "lm_head",
|
||||||
|
# Qwen models used "model" as the name for the language model.
|
||||||
|
# Therefore, we must map each of submodule explicitly to avoid
|
||||||
|
# conflicts with newer models that use "model.language_model".
|
||||||
|
"model.embed_tokens": "model.language_model.embed_tokens",
|
||||||
|
"model.layers": "model.language_model.layers",
|
||||||
|
"model.norm": "model.language_model.norm",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||||
|
# Skip SupportsMRoPE.__init__ and call the next class in MRO
|
||||||
|
super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor | None,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: IntermediateTensors | None = None,
|
||||||
|
inputs_embeds: torch.Tensor | None = None,
|
||||||
|
**kwargs: object,
|
||||||
|
) -> torch.Tensor | IntermediateTensors:
|
||||||
|
# Gemma3 and PaliGemma needs `token_type_ids` to work correctly
|
||||||
|
# Other models will not have `token_type_ids` in kwargs
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
|
||||||
|
model_output = super().forward(
|
||||||
|
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
|
||||||
|
)
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
def get_language_model(self) -> torch.nn.Module:
|
||||||
|
"""Transformers backend multimodal classes do not contain a separate vLLM
|
||||||
|
language model class. Therefore, in order to return a language model vLLM class,
|
||||||
|
we use a wrapper to give `self` the same interface as a text model."""
|
||||||
|
|
||||||
|
# Exclude self and object
|
||||||
|
bases = self.__class__.mro()[1:-1]
|
||||||
|
# Keep only classes defined in `vllm.model_executor.models.transformers`
|
||||||
|
bases = [b for b in bases if ".transformers." in b.__module__]
|
||||||
|
# Exclude MultiModalMixin itself
|
||||||
|
bases = [b for b in bases if b is not MultiModalMixin]
|
||||||
|
|
||||||
|
class LanguageModel(*bases):
|
||||||
|
def __init__(self, multimodal_model):
|
||||||
|
# Don't call super().__init__() to avoid re-initialization
|
||||||
|
self.__dict__.update(multimodal_model.__dict__)
|
||||||
|
|
||||||
|
model = getattr_iter(self.model, ("language_model", "text_model"), None)
|
||||||
|
|
||||||
|
return LanguageModel(self)
|
||||||
|
|
||||||
|
def get_multimodal_embeddings(self, **kwargs):
|
||||||
|
pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
|
||||||
|
image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
|
||||||
|
# Model might use `image_patches` instead of `pixel_values`
|
||||||
|
if pixel_values is None:
|
||||||
|
pixel_values = kwargs.pop("image_patches", None)
|
||||||
|
|
||||||
|
if image_embeds is not None:
|
||||||
|
return image_embeds
|
||||||
|
|
||||||
|
if pixel_values is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
num_image_patches = kwargs.pop("num_image_patches")
|
||||||
|
kwargs.pop("token_type_ids", None) # used only in `forward`
|
||||||
|
if pixel_values is not None:
|
||||||
|
vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
|
||||||
|
|
||||||
|
if isinstance(vision_embeddings, torch.Tensor):
|
||||||
|
if vision_embeddings.ndim == 2:
|
||||||
|
vision_embeddings = vision_embeddings.unsqueeze(0)
|
||||||
|
|
||||||
|
# Embeddings have to be 2D tensors of length `num_images`
|
||||||
|
# but transformers returns concat tensors if each patch
|
||||||
|
# is of different size. We split it back to make vLLM happy
|
||||||
|
vision_embeddings = torch.split(
|
||||||
|
vision_embeddings, num_image_patches.flatten().tolist()
|
||||||
|
)
|
||||||
|
vision_embeddings = [
|
||||||
|
embed.flatten(start_dim=0, end_dim=-2)
|
||||||
|
for embed in vision_embeddings
|
||||||
|
]
|
||||||
|
|
||||||
|
return vision_embeddings
|
||||||
|
|
||||||
|
def get_mrope_input_positions(
|
||||||
|
self,
|
||||||
|
input_tokens: list[int],
|
||||||
|
hf_config: "PretrainedConfig",
|
||||||
|
image_grid_thw: list[list[int]] | torch.Tensor | None,
|
||||||
|
video_grid_thw: list[list[int]] | torch.Tensor | None,
|
||||||
|
second_per_grid_ts: list[float] | None = None,
|
||||||
|
context_len: int = 0,
|
||||||
|
seq_len: int | None = None,
|
||||||
|
audio_feature_lengths: torch.Tensor | None = None,
|
||||||
|
use_audio_in_video: bool = False,
|
||||||
|
) -> tuple[torch.Tensor, int]:
|
||||||
|
if any((second_per_grid_ts, audio_feature_lengths, use_audio_in_video)):
|
||||||
|
raise NotImplementedError("Transformers backend only supports images.")
|
||||||
|
|
||||||
|
if isinstance(image_grid_thw, list):
|
||||||
|
image_grid_thw = torch.tensor(image_grid_thw)
|
||||||
|
if isinstance(video_grid_thw, list):
|
||||||
|
video_grid_thw = torch.tensor(video_grid_thw)
|
||||||
|
|
||||||
|
mrope_positions, mrope_position_delta = self.model.get_rope_index(
|
||||||
|
input_ids=torch.tensor(input_tokens).unsqueeze(0),
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
)
|
||||||
|
|
||||||
|
mrope_positions = mrope_positions[:, 0, context_len:seq_len]
|
||||||
|
mrope_position_delta = mrope_position_delta[0].item()
|
||||||
|
|
||||||
|
return mrope_positions, mrope_position_delta
|
||||||
118
vllm/model_executor/models/transformers/pooling.py
Normal file
118
vllm/model_executor/models/transformers/pooling.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Copyright 2024 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Transformers backend mixins for pooling models."""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForSequenceClassification
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.pooler import (
|
||||||
|
ClassifierPooler,
|
||||||
|
CLSPool,
|
||||||
|
DispatchPooler,
|
||||||
|
Pooler,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||||
|
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingMixin(VllmModelForPooling):
|
||||||
|
default_pooling_type = "CLS"
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||||
|
# Skip VllmModelForPooling.__init__ and call the next class in MRO
|
||||||
|
super(VllmModelForPooling, self).__init__(
|
||||||
|
vllm_config=vllm_config, prefix=prefix
|
||||||
|
)
|
||||||
|
|
||||||
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
assert pooler_config is not None
|
||||||
|
|
||||||
|
self.pooler = DispatchPooler(
|
||||||
|
{
|
||||||
|
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||||
|
"embed": Pooler.for_embed(pooler_config),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
|
||||||
|
default_pooling_type = "CLS"
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
|
||||||
|
# Skip VllmModelForPooling.__init__ and call the next class in MRO
|
||||||
|
super(VllmModelForPooling, self).__init__(
|
||||||
|
vllm_config=vllm_config, prefix=prefix
|
||||||
|
)
|
||||||
|
|
||||||
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
assert pooler_config is not None
|
||||||
|
|
||||||
|
# Certain information about the the model and classifier can only be
|
||||||
|
# inferred from the `ForSequenceClassification` class. Therefore, we
|
||||||
|
# instantiate it on the "meta" device to avoid allocating GPU memory.
|
||||||
|
with torch.device("meta"):
|
||||||
|
seq_cls_model = AutoModelForSequenceClassification.from_config(
|
||||||
|
self.config,
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
# When used for sequence classification, some models have their
|
||||||
|
# pooling layers removed. Make sure this is reflected in vLLM.
|
||||||
|
for module in seq_cls_model.modules():
|
||||||
|
if hasattr(module, "pooler") and module.pooler is None:
|
||||||
|
self.model.pooler = None
|
||||||
|
break
|
||||||
|
if self.model.pooler is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Sequence classification models with pooling layers are not "
|
||||||
|
"supported yet in the Transformers backend."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unlike `lm_head`, `classifier` is not always `nn.Linear`.
|
||||||
|
self.classifier = seq_cls_model.classifier
|
||||||
|
self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)
|
||||||
|
|
||||||
|
class ClassifierWithReshape(self.classifier.__class__):
|
||||||
|
"""CLSPool has already been applied in `pooling`.
|
||||||
|
Add dim to match expected input shape of `classifier.forward`."""
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if len(args) > 0:
|
||||||
|
args = (args[0].unsqueeze(1), *args[1:])
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
self.classifier.__class__ = ClassifierWithReshape
|
||||||
|
|
||||||
|
self.pooler = DispatchPooler(
|
||||||
|
{
|
||||||
|
"token_classify": Pooler.for_token_classify(
|
||||||
|
pooler_config, classifier=self.classifier
|
||||||
|
),
|
||||||
|
"classify": ClassifierPooler(
|
||||||
|
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
|
||||||
|
),
|
||||||
|
"score": ClassifierPooler(
|
||||||
|
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
207
vllm/model_executor/models/transformers/utils.py
Normal file
207
vllm/model_executor/models/transformers/utils.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Copyright 2024 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Transformers backend utilities."""
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.config.utils import getattr_iter
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from `accelerate`
|
||||||
|
@contextmanager
|
||||||
|
def init_on_device_without_buffers(device: torch.device):
|
||||||
|
"""
|
||||||
|
A context manager under which models are initialized with all
|
||||||
|
parameters on the specified device. However buffers are not
|
||||||
|
initialized on specified device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (`torch.device`):
|
||||||
|
Device to initialize all parameters on.
|
||||||
|
"""
|
||||||
|
|
||||||
|
old_register_parameter = nn.Module.register_parameter
|
||||||
|
|
||||||
|
def register_empty_parameter(module, name, param):
|
||||||
|
old_register_parameter(module, name, param)
|
||||||
|
if param is not None:
|
||||||
|
param_cls = type(module._parameters[name])
|
||||||
|
kwargs = module._parameters[name].__dict__
|
||||||
|
kwargs["requires_grad"] = param.requires_grad
|
||||||
|
module._parameters[name] = param_cls(
|
||||||
|
module._parameters[name].to(device), **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor_constructors_to_patch = {}
|
||||||
|
|
||||||
|
def patch_tensor_constructor(fn):
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
kwargs["device"] = device
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
try:
|
||||||
|
nn.Module.register_parameter = register_empty_parameter
|
||||||
|
for torch_function_name in tensor_constructors_to_patch:
|
||||||
|
setattr(
|
||||||
|
torch,
|
||||||
|
torch_function_name,
|
||||||
|
patch_tensor_constructor(getattr(torch, torch_function_name)),
|
||||||
|
)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
nn.Module.register_parameter = old_register_parameter
|
||||||
|
for (
|
||||||
|
torch_function_name,
|
||||||
|
old_torch_function,
|
||||||
|
) in tensor_constructors_to_patch.items():
|
||||||
|
setattr(torch, torch_function_name, old_torch_function)
|
||||||
|
|
||||||
|
|
||||||
|
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
|
||||||
|
|
||||||
|
|
||||||
|
def replace_linear_class(
|
||||||
|
linear: nn.Linear,
|
||||||
|
style: Style = "replicate",
|
||||||
|
quant_config: "QuantizationConfig | None" = None,
|
||||||
|
*,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear:
|
||||||
|
"""
|
||||||
|
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
linear: `nn.Linear` to be replaced.
|
||||||
|
style: Tensor parallel style of the new linear, e.g. "colwise".
|
||||||
|
quant_config: Quantization config for the new linear.
|
||||||
|
Returns:
|
||||||
|
The new linear.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(style, str):
|
||||||
|
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
|
||||||
|
|
||||||
|
vllm_linear_cls, vllm_linear_kwargs = {
|
||||||
|
"colwise": (ColumnParallelLinear, {}),
|
||||||
|
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
|
||||||
|
"rowwise": (RowParallelLinear, {}),
|
||||||
|
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
|
||||||
|
"replicate": (ReplicatedLinear, {}),
|
||||||
|
}.get(style, (ReplicatedLinear, {}))
|
||||||
|
|
||||||
|
return vllm_linear_cls(
|
||||||
|
input_size=linear.in_features,
|
||||||
|
output_size=linear.out_features,
|
||||||
|
bias=linear.bias is not None,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix,
|
||||||
|
return_bias=False,
|
||||||
|
**vllm_linear_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
|
||||||
|
"""Replace a Transformers RMSNorm with vLLM's RMSNorm.
|
||||||
|
|
||||||
|
This method assumes:
|
||||||
|
- Weight is stored as `weight`.
|
||||||
|
- Epsilon is stored as `eps` or `variance_epsilon`.
|
||||||
|
- `with_scale` indicates whether the layer has a weight (Gemma3n only).
|
||||||
|
- `var_hidden_size` is only ever used for Intern vision encoder in vLLM
|
||||||
|
and Transformers doesn't appear to have the same concept.
|
||||||
|
"""
|
||||||
|
eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
|
||||||
|
kwargs = {"hidden_size": hidden_size, "eps": eps}
|
||||||
|
# Update hidden size if weight is available
|
||||||
|
weight_meta = getattr(rms_norm, "weight", None)
|
||||||
|
if weight_meta is not None:
|
||||||
|
kwargs["hidden_size"] = weight_meta.size(0)
|
||||||
|
# Check if weight is all zeros, which indicates GemmaRMSNorm
|
||||||
|
# We must create a new instance because rms_norm is on meta
|
||||||
|
try:
|
||||||
|
with torch.device("cpu"):
|
||||||
|
weight_test = getattr(rms_norm.__class__(1), "weight", None)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to determine if RMSNorm weight is centered on zero or one. "
|
||||||
|
"Defaulting to one."
|
||||||
|
)
|
||||||
|
weight_test = None
|
||||||
|
if weight_test is not None and torch.all(weight_test == 0):
|
||||||
|
return GemmaRMSNorm(**kwargs)
|
||||||
|
# Otherwise assume it's a regular RMSNorm
|
||||||
|
kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
|
||||||
|
if weight_meta is not None:
|
||||||
|
kwargs["dtype"] = weight_meta.dtype
|
||||||
|
else:
|
||||||
|
# No weight, fall back to weightless RMSNorm
|
||||||
|
kwargs["has_weight"] = False
|
||||||
|
return RMSNorm(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
|
||||||
|
logger.debug("%s: %s -> %s", name, old_module, new_module)
|
||||||
|
|
||||||
|
|
||||||
|
def get_feature_request_tip(
|
||||||
|
model: str,
|
||||||
|
trust_remote_code: bool,
|
||||||
|
) -> str:
|
||||||
|
hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new"
|
||||||
|
gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
|
||||||
|
url = hf_url if trust_remote_code else gh_url
|
||||||
|
prefix = f"Please open {url} to request support for this feature. "
|
||||||
|
if Path(model).exists():
|
||||||
|
prefix = ""
|
||||||
|
doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
|
||||||
|
tip = f"See {doc_url} for instructions on how to add support yourself."
|
||||||
|
return f"{prefix}{tip}"
|
||||||
|
|
||||||
|
|
||||||
|
def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
|
||||||
|
"""
|
||||||
|
Callable to be passed to `@support_torch_compile`'s `enable_if` argument.
|
||||||
|
|
||||||
|
Defaults to `True` but is disabled in the following situations:
|
||||||
|
|
||||||
|
- The model uses dynamic rope scaling.
|
||||||
|
"""
|
||||||
|
text_config = vllm_config.model_config.hf_config.get_text_config()
|
||||||
|
# Dynamic rope scaling is not compatible with torch.compile
|
||||||
|
rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
|
||||||
|
return rope_scaling.get("rope_type") != "dynamic"
|
||||||
@@ -1,215 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
# Copyright 2024 The vLLM team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Wrapper around `transformers` models for pooling tasks."""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForSequenceClassification
|
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionType
|
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
from vllm.model_executor.layers.pooler import (
|
|
||||||
ClassifierPooler,
|
|
||||||
CLSPool,
|
|
||||||
DispatchPooler,
|
|
||||||
Pooler,
|
|
||||||
)
|
|
||||||
from vllm.sequence import IntermediateTensors
|
|
||||||
|
|
||||||
from .interfaces_base import VllmModelForPooling
|
|
||||||
from .transformers import TransformersBase, can_enable_torch_compile
|
|
||||||
from .transformers_moe import TransformersMoEBase
|
|
||||||
from .utils import WeightsMapper
|
|
||||||
|
|
||||||
|
|
||||||
class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
|
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
|
||||||
# These are applied in order, so the order matters!
|
|
||||||
orig_to_new_prefix={
|
|
||||||
# Handle BERT-like models
|
|
||||||
"roberta": "model",
|
|
||||||
"bert": "model",
|
|
||||||
# Add `model.` prefix for base model checkpoints
|
|
||||||
"": "model.",
|
|
||||||
# Remove `model.` prefix if it was already there
|
|
||||||
"model.model.": "model.",
|
|
||||||
# Classifier/scoring heads will be adjacent to `model`
|
|
||||||
"model.score": "classifier",
|
|
||||||
"model.classifier": "classifier",
|
|
||||||
},
|
|
||||||
orig_to_new_suffix={
|
|
||||||
# Replace legacy suffixes used for norms
|
|
||||||
".gamma": ".weight",
|
|
||||||
".beta": ".bias",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
|
||||||
|
|
||||||
# Skip unsupported/unwanted output embeddings layers
|
|
||||||
self.skip_prefixes.extend(
|
|
||||||
[
|
|
||||||
"model.lm_head.",
|
|
||||||
"model.predictions.",
|
|
||||||
"model.qa_outputs.",
|
|
||||||
"model.embeddings_project.",
|
|
||||||
"model.discriminator_predictions.",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Some encoder models have the position_ids buffer in the checkpoint.
|
|
||||||
# vLLM will always pass position_ids as an argument, so we skip loading
|
|
||||||
# the buffer if it exists
|
|
||||||
self.skip_substrs.append("position_ids")
|
|
||||||
|
|
||||||
# Some encoder models have the bias of the final classifier layer
|
|
||||||
# in the checkpoint. vLLM does not use this bias, so we skip loading
|
|
||||||
# it if it exists
|
|
||||||
self.skip_substrs.append("score.bias")
|
|
||||||
|
|
||||||
# roberta-like models an extra padding in positions.
|
|
||||||
# FIXME(Isotr0py): This is quite hacky for roberta edge case,
|
|
||||||
# we should find a better way to handle this.
|
|
||||||
self.is_roberta = "roberta" in self.text_config.model_type
|
|
||||||
self.padding_idx = self.text_config.pad_token_id
|
|
||||||
|
|
||||||
def create_attention_instances(
|
|
||||||
self, attn_type: AttentionType = AttentionType.DECODER
|
|
||||||
) -> dict[int, Attention]:
|
|
||||||
# TODO(hmellor): Better way to detect encoder models
|
|
||||||
# In encoder models, the attention layers will have `is_causal=False`
|
|
||||||
is_encoder = lambda m: not getattr(m, "is_causal", True)
|
|
||||||
# vLLM does not support encoder-decoder models, so if any encoder layer
|
|
||||||
# is found, we assume the whole model is an encoder model
|
|
||||||
if any(is_encoder(m) for m in self.model.modules()):
|
|
||||||
attn_type = AttentionType.ENCODER_ONLY
|
|
||||||
|
|
||||||
# Check minimum transformers version for encoder models support
|
|
||||||
if attn_type == AttentionType.ENCODER_ONLY:
|
|
||||||
self.check_version("4.57.0.dev0", "encoder models support")
|
|
||||||
|
|
||||||
return super().create_attention_instances(attn_type)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor | None,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
intermediate_tensors: IntermediateTensors | None = None,
|
|
||||||
inputs_embeds: torch.Tensor | None = None,
|
|
||||||
) -> torch.Tensor | IntermediateTensors:
|
|
||||||
if self.is_roberta:
|
|
||||||
# RoBERTa-specific positions padding
|
|
||||||
positions += self.padding_idx + 1
|
|
||||||
return super().forward(
|
|
||||||
input_ids=input_ids,
|
|
||||||
positions=positions,
|
|
||||||
intermediate_tensors=intermediate_tensors,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
|
||||||
class TransformersEmbeddingModel(TransformersPoolingBase):
|
|
||||||
default_pooling_type = "CLS"
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
|
||||||
assert pooler_config is not None
|
|
||||||
|
|
||||||
self.pooler = DispatchPooler(
|
|
||||||
{
|
|
||||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
|
||||||
"embed": Pooler.for_embed(pooler_config),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
|
||||||
class TransformersForSequenceClassification(TransformersPoolingBase):
|
|
||||||
default_pooling_type = "CLS"
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
|
||||||
assert pooler_config is not None
|
|
||||||
|
|
||||||
# Certain information about the the model and classifier can only be
|
|
||||||
# inferred from the `ForSequenceClassification` class. Therefore, we
|
|
||||||
# instantiate it on the "meta" device to avoid allocating GPU memory.
|
|
||||||
with torch.device("meta"):
|
|
||||||
seq_cls_model = AutoModelForSequenceClassification.from_config(
|
|
||||||
self.config,
|
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
trust_remote_code=self.model_config.trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
# When used for sequence classification, some models have their
|
|
||||||
# pooling layers removed. Make sure this is reflected in vLLM.
|
|
||||||
for module in seq_cls_model.modules():
|
|
||||||
if hasattr(module, "pooler") and module.pooler is None:
|
|
||||||
self.model.pooler = None
|
|
||||||
break
|
|
||||||
if self.model.pooler is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"Sequence classification models with pooling layers are not "
|
|
||||||
"supported yet in the Transformers backend."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Unlike `lm_head`, `classifier` is not always `nn.Linear`.
|
|
||||||
self.classifier = seq_cls_model.classifier
|
|
||||||
self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)
|
|
||||||
|
|
||||||
class ClassifierWithReshape(self.classifier.__class__):
|
|
||||||
"""CLSPool has already been applied in `pooling`.
|
|
||||||
Add dim to match expected input shape of `classifier.forward`."""
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
if len(args) > 0:
|
|
||||||
args = (args[0].unsqueeze(1), *args[1:])
|
|
||||||
return super().forward(*args, **kwargs)
|
|
||||||
|
|
||||||
self.classifier.__class__ = ClassifierWithReshape
|
|
||||||
|
|
||||||
self.pooler = DispatchPooler(
|
|
||||||
{
|
|
||||||
"token_classify": Pooler.for_token_classify(
|
|
||||||
pooler_config, classifier=self.classifier
|
|
||||||
),
|
|
||||||
"classify": ClassifierPooler(
|
|
||||||
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
|
|
||||||
),
|
|
||||||
"score": ClassifierPooler(
|
|
||||||
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
|
||||||
class TransformersMoEEmbeddingModel(TransformersMoEBase, TransformersEmbeddingModel):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
|
||||||
class TransformersMoEForSequenceClassification(
|
|
||||||
TransformersMoEBase, TransformersForSequenceClassification
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
Reference in New Issue
Block a user