mirror of
https://github.com/vllm-project/vllm.git
synced 2025-12-06 15:04:47 +08:00
50 lines
1.8 KiB
Python
50 lines
1.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import torch
|
|
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
|
|
def test_sequence_intermediate_tensors_equal():
    """Exercise ``==`` / ``!=`` between ``IntermediateTensors`` objects.

    The scenarios asserted below are, in order:

    * a base instance versus an instance of a subclass, both built from an
      empty dict, must compare unequal;
    * two base instances built from empty dicts must compare equal;
    * instances whose dicts use different keys must compare unequal;
    * instances sharing a key but holding tensors of different shapes must
      compare unequal;
    * instances sharing a key and holding identically built tensors must
      compare equal.
    """

    class AnotherIntermediateTensors(IntermediateTensors):
        pass

    def int32_zeros(rows, cols):
        # All tensor payloads in this test are zero-filled int32 matrices.
        return torch.zeros([rows, cols], dtype=torch.int32)

    # Base class vs. subclass, both with empty contents.
    base_empty = IntermediateTensors({})
    subclass_empty = AnotherIntermediateTensors({})
    assert base_empty != subclass_empty

    # Two empty instances of the same class.
    empty_lhs = IntermediateTensors({})
    empty_rhs = IntermediateTensors({})
    assert empty_lhs == empty_rhs

    # Same tensor contents stored under different keys.
    keyed_one = IntermediateTensors({"1": int32_zeros(2, 4)})
    keyed_two = IntermediateTensors({"2": int32_zeros(2, 4)})
    assert keyed_one != keyed_two

    # Same key, tensors of different shapes ([2, 4] vs. [2, 5]).
    shape_2x4 = IntermediateTensors({"1": int32_zeros(2, 4)})
    shape_2x5 = IntermediateTensors({"1": int32_zeros(2, 5)})
    assert shape_2x4 != shape_2x5

    # Same key, tensors built identically.
    matching_lhs = IntermediateTensors({"1": int32_zeros(2, 4)})
    matching_rhs = IntermediateTensors({"1": int32_zeros(2, 4)})
    assert matching_lhs == matching_rhs
|