From c2fb785e854c348d9db95d6d79ef4b2e15148bad Mon Sep 17 00:00:00 2001 From: cgli Date: Sat, 23 Dec 2023 21:00:20 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=9D=9Ebatch=E4=B8=8BCPU=20?= =?UTF-8?q?Attention=E7=AE=97=E5=AD=90=E5=8F=96batch=E9=94=99=E8=AF=AF(#38?= =?UTF-8?q?5)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/devices/cpu/cpudevice.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp index 2915e41d..92fb3366 100644 --- a/src/devices/cpu/cpudevice.cpp +++ b/src/devices/cpu/cpudevice.cpp @@ -282,8 +282,9 @@ namespace fastllm { float *vd = (float*)v.cpuData; float *maskd = (datas.find("mask")->second && mask.dims.size() > 0) ? (float*)mask.cpuData : nullptr; float *od = (float*)output.cpuData; - int batch = intParams.find("q___batch")->second; - int maskStride = (datas.find("mask")->second) ? (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0)) : 0; + int batch = (maskd != nullptr && mask.dims.size() == 3) ? mask.dims[0] : 1; + batch = intParams.find("mask___batch") != intParams.end() ? intParams.find("mask___batch")->second : batch; + int maskStride = (maskd != nullptr) ? (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0)) : 0; std::fill(od, od + output.Count(0), 0.0f); auto pool = GetPool(); std::vector > futures;