mirror of
https://github.com/vllm-project/vllm.git
synced 2025-12-06 06:53:12 +08:00
[Model][0/N] Improve all pooling task | clean up (#25817)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -581,7 +581,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
|
||||
| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ |
|
||||
|
||||
!!! note
|
||||
Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>.
|
||||
Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner_client.py>.
|
||||
|
||||
[](){ #supported-mm-models }
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ python examples/online_serving/pooling/jinaai_rerank_client.py
|
||||
## Named Entity Recognition (NER) usage
|
||||
|
||||
```bash
|
||||
python examples/online_serving/pooling/ner.py
|
||||
python examples/online_serving/pooling/ner_client.py
|
||||
```
|
||||
|
||||
## Openai chat embedding for multimodal usage
|
||||
|
||||
@@ -8,6 +8,8 @@ import os
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.envs import maybe_convert_bool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
VLLM_CI_NO_SKIP: bool = False
|
||||
VLLM_CI_DTYPE: str | None = None
|
||||
@@ -25,6 +27,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None),
|
||||
# Allow changing the head dtype used by transformers in tests
|
||||
"VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None),
|
||||
# Allow control over whether tests use enforce_eager
|
||||
"VLLM_CI_ENFORCE_EAGER": lambda: maybe_convert_bool(
|
||||
os.getenv("VLLM_CI_ENFORCE_EAGER", None)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -58,7 +58,9 @@ def test_pooling_params(llm: LLM):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_encode_api(llm: LLM):
|
||||
# chunked prefill does not support all pooling
|
||||
err_msg = "pooling_task must be one of.+"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
llm.encode(prompts, use_tqdm=False)
|
||||
|
||||
@@ -35,7 +35,6 @@ def llm():
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_pooling_params(llm: LLM):
|
||||
def get_outputs(normalize):
|
||||
outputs = llm.embed(
|
||||
|
||||
@@ -74,7 +74,6 @@ def test_multiple_pooling_params(llm: LLM):
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_right_side_truncation(llm: LLM):
|
||||
# Embeddings models should truncate the end of the prompt
|
||||
tokenizer = llm.get_tokenizer()
|
||||
|
||||
@@ -33,7 +33,6 @@ def llm():
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_pooling_params(llm: LLM):
|
||||
def get_outputs(activation):
|
||||
text_1 = "What is the capital of France?"
|
||||
|
||||
@@ -3,12 +3,15 @@
|
||||
# Adapted from https://huggingface.co/docs/transformers/perplexity
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
import tests.ci_envs as ci_envs
|
||||
from tests.models.utils import GenerateModelInfo, TokensTextLogprobsPromptLogprobs
|
||||
from tests.models.utils import (
|
||||
GenerateModelInfo,
|
||||
TokensTextLogprobsPromptLogprobs,
|
||||
get_vllm_extra_kwargs,
|
||||
)
|
||||
from vllm.logprobs import Logprob
|
||||
|
||||
# See #24485
|
||||
@@ -25,27 +28,10 @@ def wikitext_ppl_test(
|
||||
vllm_extra_kwargs=None,
|
||||
atol=PPL_TOL,
|
||||
):
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
|
||||
pytest.skip("Skipping test.")
|
||||
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
|
||||
|
||||
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
||||
|
||||
# Allow vllm to test using the given dtype, such as float32
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
|
||||
|
||||
# Allow vllm to test using hf_overrides
|
||||
if model_info.hf_overrides is not None:
|
||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||
|
||||
# Allow changing the head dtype used by vllm in tests
|
||||
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
|
||||
if "hf_overrides" not in vllm_extra_kwargs:
|
||||
vllm_extra_kwargs["hf_overrides"] = {}
|
||||
vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
|
||||
|
||||
with vllm_runner(
|
||||
model_info.name,
|
||||
gpu_memory_utilization=0.7,
|
||||
|
||||
47
tests/models/language/pooling/test_head_dtype.py
Normal file
47
tests/models/language/pooling/test_head_dtype.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["nie3e/sentiment-polish-gpt2-small"],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_classify_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
with hf_runner(
|
||||
model, dtype=dtype, auto_cls=AutoModelForSequenceClassification
|
||||
) as hf_model:
|
||||
hf_outputs = hf_model.classify(example_prompts)
|
||||
|
||||
for head_dtype_str in ["float32", "model"]:
|
||||
with vllm_runner(
|
||||
model,
|
||||
max_model_len=512,
|
||||
dtype=dtype,
|
||||
hf_overrides={"head_dtype": head_dtype_str},
|
||||
) as vllm_model:
|
||||
model_config = vllm_model.llm.llm_engine.model_config
|
||||
model_dtype = model_config.dtype
|
||||
head_dtype = model_config.head_dtype
|
||||
|
||||
if head_dtype_str == "float32":
|
||||
assert head_dtype == torch.float32
|
||||
elif head_dtype_str == "model":
|
||||
assert head_dtype == model_dtype
|
||||
|
||||
vllm_outputs = vllm_model.classify(example_prompts)
|
||||
|
||||
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
|
||||
hf_output = torch.tensor(hf_output).float()
|
||||
vllm_output = torch.tensor(vllm_output).float()
|
||||
|
||||
assert torch.allclose(hf_output, vllm_output, atol=1e-2)
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
import types
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -14,11 +13,12 @@ from vllm.model_executor.models.bert import (
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# 1) Functional test: SPLADE formula correctness (no HF download needed)
|
||||
# Functional test: SPLADE formula correctness (no HF download needed)
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)])
|
||||
@torch.inference_mode
|
||||
def test_splade_pooler_matches_reference_formula(B, T, H, V):
|
||||
"""Ensure SPLADESparsePooler forward() matches the mathematical formula:
|
||||
log1p(relu(logits)) -> max over sequence length (after masking)."""
|
||||
@@ -26,9 +26,11 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
|
||||
|
||||
# Prepare [B] sequences of shape [T, H]
|
||||
hs_list = [torch.randn(T, H) for _ in range(B)]
|
||||
hs_tenser = torch.cat(hs_list)
|
||||
|
||||
# Simulate PoolingMetadata (only required fields)
|
||||
prompt_lens = [T, T - 1]
|
||||
prompt_lens_tenser = torch.tensor(prompt_lens, dtype=torch.int32)
|
||||
token_ids = torch.tensor(
|
||||
[
|
||||
[101, 5, 102], # Batch 0: [CLS], token, [SEP]
|
||||
@@ -36,7 +38,9 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
|
||||
],
|
||||
dtype=torch.long,
|
||||
)
|
||||
meta = types.SimpleNamespace(prompt_lens=prompt_lens, prompt_token_ids=token_ids)
|
||||
meta = types.SimpleNamespace(
|
||||
prompt_lens=prompt_lens_tenser, prompt_token_ids=token_ids
|
||||
)
|
||||
|
||||
# MLM head (prefer BertMLMHead, fallback to Linear if unavailable)
|
||||
try:
|
||||
@@ -46,10 +50,10 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
|
||||
|
||||
# Forward pass through SPLADE pooler
|
||||
pooler = SPLADESparsePooler(mlm_head=mlm_head, pooling="max", remove_cls_sep=True)
|
||||
pooled = pooler(hidden_states=hs_list, pooling_metadata=meta) # list of [V]
|
||||
pooled = pooler(hidden_states=hs_tenser, pooling_metadata=meta) # list of [V]
|
||||
|
||||
# Basic output checks
|
||||
assert isinstance(pooled, list) and len(pooled) == B
|
||||
assert isinstance(pooled, torch.Tensor) and len(pooled) == B
|
||||
for vec in pooled:
|
||||
assert vec.shape == (V,)
|
||||
assert torch.isfinite(vec).all()
|
||||
@@ -83,40 +87,3 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# 2) Integration smoke test: end-to-end embedding path wiring
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.cpu_model
|
||||
def test_bert_splade_sparse_embed_smoke(vllm_runner, monkeypatch):
|
||||
"""Ensure BertSpladeSparseEmbeddingModel loads and produces sparse embeddings."""
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
MODEL_ID = "hf-internal-testing/tiny-random-bert"
|
||||
hf_overrides = {"architectures": ["BertSpladeSparseEmbeddingModel"]}
|
||||
|
||||
# Enforce CPU-only execution (optional)
|
||||
monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
tok = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
vocab_size = tok.vocab_size
|
||||
|
||||
# The embed path should route through SPLADESparsePooler
|
||||
with vllm_runner(
|
||||
MODEL_ID,
|
||||
runner="pooling",
|
||||
max_model_len=64,
|
||||
hf_overrides=hf_overrides,
|
||||
) as vm:
|
||||
outs = vm.embed(["hello world", "splade sparse test"])
|
||||
|
||||
# Basic sanity checks
|
||||
assert len(outs) == 2
|
||||
assert outs[0].shape[0] == vocab_size
|
||||
assert outs[1].shape[0] == vocab_size
|
||||
assert np.isfinite(outs[0]).all() and (outs[0] >= 0).all()
|
||||
assert np.isfinite(outs[1]).all() and (outs[1] >= 0).all()
|
||||
|
||||
@@ -6,12 +6,16 @@ from collections.abc import Sequence
|
||||
|
||||
import mteb
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
|
||||
import tests.ci_envs as ci_envs
|
||||
from tests.models.utils import EmbedModelInfo, RerankModelInfo, check_embeddings_close
|
||||
from tests.models.utils import (
|
||||
EmbedModelInfo,
|
||||
RerankModelInfo,
|
||||
check_embeddings_close,
|
||||
get_vllm_extra_kwargs,
|
||||
)
|
||||
|
||||
# Most embedding models on the STS12 task (See #17175):
|
||||
# - Model implementation and minor changes in tensor dtype
|
||||
@@ -165,28 +169,11 @@ def mteb_test_embed_models(
|
||||
hf_model_callback=None,
|
||||
atol=MTEB_EMBED_TOL,
|
||||
):
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
|
||||
pytest.skip("Skipping test.")
|
||||
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
|
||||
|
||||
# Test embed_dims, isnan and whether to use normalize
|
||||
example_prompts = ["The chef prepared a delicious meal." * 1000]
|
||||
|
||||
# Allow vllm to test using the given dtype, such as float32
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
|
||||
|
||||
# Allow vllm to test using hf_overrides
|
||||
if model_info.hf_overrides is not None:
|
||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||
|
||||
# Allow changing the head dtype used by vllm in tests
|
||||
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
|
||||
if "hf_overrides" not in vllm_extra_kwargs:
|
||||
vllm_extra_kwargs["hf_overrides"] = {}
|
||||
vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
|
||||
|
||||
with vllm_runner(
|
||||
model_info.name,
|
||||
runner="pooling",
|
||||
@@ -212,9 +199,12 @@ def mteb_test_embed_models(
|
||||
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
|
||||
head_dtype = model_config.head_dtype
|
||||
|
||||
# Test embed_dims, isnan and whether to use normalize
|
||||
# Test embedding_size, isnan and whether to use normalize
|
||||
vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1)
|
||||
assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))
|
||||
outputs_tensor = torch.tensor(vllm_outputs)
|
||||
assert not torch.any(torch.isnan(outputs_tensor))
|
||||
embedding_size = model_config.embedding_size
|
||||
assert torch.tensor(vllm_outputs).shape[-1] == embedding_size
|
||||
|
||||
# Accelerate mteb test by setting
|
||||
# SentenceTransformers mteb score to a constant
|
||||
@@ -231,7 +221,7 @@ def mteb_test_embed_models(
|
||||
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
|
||||
st_dtype = next(hf_model.model.parameters()).dtype
|
||||
|
||||
# Test embed_dims and whether to use normalize
|
||||
# Check embeddings close to hf outputs
|
||||
hf_outputs = hf_model.encode(example_prompts)
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=hf_outputs,
|
||||
@@ -323,24 +313,7 @@ def mteb_test_rerank_models(
|
||||
vllm_mteb_encoder=VllmMtebEncoder,
|
||||
atol=MTEB_RERANK_TOL,
|
||||
):
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
|
||||
pytest.skip("Skipping test.")
|
||||
|
||||
# Allow vllm to test using the given dtype, such as float32
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
|
||||
|
||||
# Allow vllm to test using hf_overrides
|
||||
if model_info.hf_overrides is not None:
|
||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||
|
||||
# Allow changing the head dtype used by vllm in tests
|
||||
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
|
||||
if "hf_overrides" not in vllm_extra_kwargs:
|
||||
vllm_extra_kwargs["hf_overrides"] = {}
|
||||
vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
|
||||
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
|
||||
|
||||
with vllm_runner(
|
||||
model_info.name,
|
||||
|
||||
@@ -15,6 +15,7 @@ from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
|
||||
from vllm.multimodal.processing import InputProcessingContext
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
|
||||
from .. import ci_envs
|
||||
from .registry import HF_EXAMPLE_MODELS
|
||||
|
||||
TokensText = tuple[list[int], str]
|
||||
@@ -414,6 +415,35 @@ class GenerateModelInfo(ModelInfo):
|
||||
hf_ppl: float | None = None
|
||||
|
||||
|
||||
def get_vllm_extra_kwargs(model_info: ModelInfo, vllm_extra_kwargs):
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
|
||||
import pytest
|
||||
|
||||
pytest.skip("Skipping test.")
|
||||
|
||||
# Allow vllm to test using the given dtype, such as float32
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
|
||||
|
||||
# Allow vllm to test using hf_overrides
|
||||
if model_info.hf_overrides is not None:
|
||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||
|
||||
# Allow changing the head dtype used by vllm in tests
|
||||
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
|
||||
if "hf_overrides" not in vllm_extra_kwargs:
|
||||
vllm_extra_kwargs["hf_overrides"] = {}
|
||||
vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
|
||||
|
||||
# Allow control over whether tests use enforce_eager
|
||||
if ci_envs.VLLM_CI_ENFORCE_EAGER is not None:
|
||||
vllm_extra_kwargs["enforce_eager"] = ci_envs.VLLM_CI_ENFORCE_EAGER
|
||||
|
||||
return vllm_extra_kwargs
|
||||
|
||||
|
||||
def dummy_hf_overrides(
|
||||
hf_config: PretrainedConfig,
|
||||
*,
|
||||
|
||||
@@ -30,6 +30,7 @@ from vllm.transformers_utils.config import (
|
||||
get_sentence_transformer_tokenizer_config,
|
||||
is_encoder_decoder,
|
||||
is_interleaved,
|
||||
try_get_dense_modules,
|
||||
try_get_generation_config,
|
||||
try_get_safetensors_metadata,
|
||||
try_get_tokenizer_config,
|
||||
@@ -1681,6 +1682,20 @@ class ModelConfig:
|
||||
logger.debug_once("head dtype: %s", head_dtype)
|
||||
return head_dtype
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
if hasattr(self.hf_config, "hidden_size"):
|
||||
return self.hf_config.hidden_size
|
||||
text_config = self.hf_config.get_text_config()
|
||||
return text_config.hidden_size
|
||||
|
||||
@property
|
||||
def embedding_size(self):
|
||||
dense_modules = try_get_dense_modules(self.model, revision=self.revision)
|
||||
if dense_modules is not None:
|
||||
return dense_modules[-1]["out_features"]
|
||||
return self.hidden_size
|
||||
|
||||
def get_and_verify_max_len(self, max_model_len: int):
|
||||
# Consider max_model_len in tokenizer_config only when
|
||||
# pooling models use absolute position_embedding.
|
||||
|
||||
@@ -13,7 +13,10 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.models.config import VerifyAndUpdateConfig
|
||||
from vllm.transformers_utils.config import get_hf_file_bytes, get_hf_file_to_dict
|
||||
from vllm.transformers_utils.config import (
|
||||
get_hf_file_bytes,
|
||||
try_get_dense_modules,
|
||||
)
|
||||
|
||||
from .interfaces_base import VllmModelForPooling, is_pooling_model
|
||||
|
||||
@@ -35,43 +38,25 @@ _GENERATE_SUFFIXES = [
|
||||
def _load_st_projector(model_config: "ModelConfig") -> nn.Module | None:
|
||||
"""Load Sentence-Transformers Dense projection layers."""
|
||||
|
||||
dense_modules = try_get_dense_modules(
|
||||
model_config.model, revision=model_config.revision
|
||||
)
|
||||
|
||||
if dense_modules is None:
|
||||
return
|
||||
|
||||
try:
|
||||
modules = get_hf_file_to_dict(
|
||||
"modules.json", model_config.model, model_config.revision
|
||||
)
|
||||
if not modules:
|
||||
return None
|
||||
|
||||
if isinstance(modules, dict):
|
||||
modules = modules.get("modules", [])
|
||||
|
||||
dense_modules = [
|
||||
m for m in modules if m.get("type") == "sentence_transformers.models.Dense"
|
||||
]
|
||||
if not dense_modules:
|
||||
return None
|
||||
|
||||
layers = []
|
||||
for module in dense_modules:
|
||||
folder = module.get("path", "")
|
||||
|
||||
config_path = f"{folder}/config.json" if folder else "config.json"
|
||||
layer_config = get_hf_file_to_dict(
|
||||
config_path, model_config.model, model_config.revision
|
||||
)
|
||||
if not layer_config:
|
||||
continue
|
||||
|
||||
for layer_config in dense_modules:
|
||||
folder = layer_config["folder"]
|
||||
linear = nn.Linear(
|
||||
layer_config.get("in_features", 768),
|
||||
layer_config.get("out_features", 768),
|
||||
layer_config["in_features"],
|
||||
layer_config["out_features"],
|
||||
bias=layer_config.get("bias", True),
|
||||
dtype=model_config.head_dtype,
|
||||
)
|
||||
|
||||
if not _load_dense_weights(linear, folder, model_config):
|
||||
continue
|
||||
|
||||
layers.append(linear)
|
||||
if act_name := layer_config.get("activation_function"):
|
||||
layers.append(get_act_fn(act_name))
|
||||
@@ -303,18 +288,18 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import get_model_hidden_size, maybe_prefix
|
||||
from .utils import maybe_prefix
|
||||
|
||||
class ModelForSequenceClassification(
|
||||
_create_pooling_model_cls(cls), SupportsCrossEncoding
|
||||
):
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
quant_config = vllm_config.quant_config
|
||||
hidden_size = get_model_hidden_size(config)
|
||||
|
||||
self.score = ReplicatedLinear(
|
||||
hidden_size,
|
||||
model_config.hidden_size,
|
||||
config.num_labels,
|
||||
bias=False,
|
||||
params_dtype=torch.float32,
|
||||
|
||||
@@ -50,7 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from ..layers.pooler import DispatchPooler, Pooler
|
||||
from .interfaces import SupportsPP
|
||||
from .interfaces import SupportsCrossEncoding, SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
is_pp_missing_parameter,
|
||||
@@ -321,7 +321,7 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
class GPT2ForSequenceClassification(nn.Module):
|
||||
class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
"""GPT2 Model for sequence classification.
|
||||
|
||||
This class expands GPT2Model with pooling and score functions - last token
|
||||
@@ -358,6 +358,9 @@ class GPT2ForSequenceClassification(nn.Module):
|
||||
}
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@@ -148,37 +148,6 @@ class GritLMMeanPool(nn.Module):
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return PoolingParamsUpdate(requires_token_ids=True)
|
||||
|
||||
def forward_one(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
prompt_len: torch.Tensor | None = None,
|
||||
instr_len: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert prompt_len is None or prompt_len == hidden_states.shape[0], (
|
||||
"partial prefill not supported with MEAN pooling"
|
||||
)
|
||||
|
||||
return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32)
|
||||
|
||||
def forward_all(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
prompt_lens: torch.Tensor,
|
||||
instr_lens: torch.Tensor,
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
offset = 0
|
||||
pooled_data = list[torch.Tensor]()
|
||||
|
||||
for prompt_len, instr_len in zip(prompt_lens, instr_lens):
|
||||
pooled_data.append(
|
||||
hidden_states[offset + instr_len : offset + prompt_len].mean(
|
||||
dim=0, dtype=torch.float32
|
||||
)
|
||||
)
|
||||
offset += prompt_len
|
||||
|
||||
return pooled_data
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
@@ -190,18 +159,20 @@ class GritLMMeanPool(nn.Module):
|
||||
self._get_instruction_len(token_ids.cpu().numpy())
|
||||
for token_ids in get_prompt_token_ids(pooling_metadata)
|
||||
],
|
||||
device=prompt_lens.device,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
if isinstance(hidden_states, list):
|
||||
return [
|
||||
self.forward_one(h, prompt_len, instr_len)
|
||||
for h, prompt_len, instr_len in zip(
|
||||
hidden_states, prompt_lens, instr_lens
|
||||
offset = 0
|
||||
pooled_data = list[torch.Tensor]()
|
||||
for prompt_len, instr_len in zip(prompt_lens, instr_lens):
|
||||
pooled_data.append(
|
||||
hidden_states[offset + instr_len : offset + prompt_len].mean(
|
||||
dim=0, dtype=torch.float32
|
||||
)
|
||||
]
|
||||
)
|
||||
offset += prompt_len
|
||||
|
||||
return self.forward_all(hidden_states, prompt_lens, instr_lens)
|
||||
return pooled_data
|
||||
|
||||
|
||||
class GritLMPooler(Pooler):
|
||||
|
||||
@@ -777,13 +777,6 @@ def fast_topk(
|
||||
return torch.topk(values, topk, dim=dim)
|
||||
|
||||
|
||||
def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
|
||||
if hasattr(hf_config, "hidden_size"):
|
||||
return hf_config.hidden_size
|
||||
text_config = hf_config.get_text_config()
|
||||
return text_config.hidden_size
|
||||
|
||||
|
||||
# Chunk x along the num_tokens axis for sequence parallelism
|
||||
# NOTE: This is wrapped in a torch custom op to work around the following issue:
|
||||
# The output tensor can have a sequence length 0 at small input sequence lengths
|
||||
|
||||
@@ -1049,6 +1049,40 @@ def try_get_tokenizer_config(
|
||||
return None
|
||||
|
||||
|
||||
@cache
|
||||
def try_get_dense_modules(
|
||||
model: str | Path,
|
||||
revision: str | None = None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
try:
|
||||
modules = get_hf_file_to_dict("modules.json", model, revision)
|
||||
if not modules:
|
||||
return None
|
||||
|
||||
if isinstance(modules, dict):
|
||||
modules = modules.get("modules", [])
|
||||
|
||||
dense_modules = [
|
||||
m for m in modules if m.get("type") == "sentence_transformers.models.Dense"
|
||||
]
|
||||
if not dense_modules:
|
||||
return None
|
||||
|
||||
layer_configs = []
|
||||
for module in dense_modules:
|
||||
folder = module.get("path", "")
|
||||
|
||||
config_path = f"{folder}/config.json" if folder else "config.json"
|
||||
layer_config = get_hf_file_to_dict(config_path, model, revision)
|
||||
if not layer_config:
|
||||
continue
|
||||
layer_config["folder"] = folder
|
||||
layer_configs.append(layer_config)
|
||||
return layer_configs
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_safetensors_params_metadata(
|
||||
model: str,
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user