diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp index 2915e41d..92fb3366 100644 --- a/src/devices/cpu/cpudevice.cpp +++ b/src/devices/cpu/cpudevice.cpp @@ -282,8 +282,9 @@ namespace fastllm { float *vd = (float*)v.cpuData; float *maskd = (datas.find("mask")->second && mask.dims.size() > 0) ? (float*)mask.cpuData : nullptr; float *od = (float*)output.cpuData; - int batch = intParams.find("q___batch")->second; - int maskStride = (datas.find("mask")->second) ? (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0)) : 0; + int batch = (maskd != nullptr && mask.dims.size() == 3) ? mask.dims[0] : 1; + batch = intParams.find("mask___batch") != intParams.end() ? intParams.find("mask___batch")->second : batch; + int maskStride = (maskd != nullptr) ? (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0)) : 0; std::fill(od, od + output.Count(0), 0.0f); auto pool = GetPool(); std::vector > futures;