[CPU]Parallelize over tokens in int4 moe (#29600)

Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
This commit is contained in:
Zhang Xiangze
2025-12-02 14:21:39 +08:00
committed by GitHub
parent 4b612664fd
commit 13ea39bc09

View File

@@ -93,16 +93,16 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
}
auto Y_all = at::empty({offsets[E], H}, x_c.options());
at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
at::parallel_for(0, offsets[E], 0, [&](int64_t idx_begin, int64_t idx_end) {
c10::InferenceMode guard;
for (int64_t e = e_begin; e < e_end; ++e) {
const int64_t te = counts[e];
if (te == 0) {
for (int64_t e = 0; e < E; ++e) {
int64_t start = std::max(offsets[e], idx_begin);
int64_t end = std::min(offsets[e + 1], idx_end);
int64_t te = end - start;
if (te <= 0) {
continue;
}
const int64_t start = offsets[e];
auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto w13_e = w13_packed.select(/*dim=*/0, e);