mirror of
https://github.com/vllm-project/vllm.git
synced 2025-12-06 06:53:12 +08:00
[CPU]Parallelize over tokens in int4 moe (#29600)
Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
This commit is contained in:
@@ -93,16 +93,16 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||
}
|
||||
auto Y_all = at::empty({offsets[E], H}, x_c.options());
|
||||
|
||||
at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
|
||||
at::parallel_for(0, offsets[E], 0, [&](int64_t idx_begin, int64_t idx_end) {
|
||||
c10::InferenceMode guard;
|
||||
for (int64_t e = e_begin; e < e_end; ++e) {
|
||||
const int64_t te = counts[e];
|
||||
if (te == 0) {
|
||||
for (int64_t e = 0; e < E; ++e) {
|
||||
int64_t start = std::max(offsets[e], idx_begin);
|
||||
int64_t end = std::min(offsets[e + 1], idx_end);
|
||||
int64_t te = end - start;
|
||||
if (te <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t start = offsets[e];
|
||||
|
||||
auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
|
||||
|
||||
auto w13_e = w13_packed.select(/*dim=*/0, e);
|
||||
|
||||
Reference in New Issue
Block a user