diff --git a/src/devices/cuda/cudadevicebatch.cpp b/src/devices/cuda/cudadevicebatch.cpp index db381f7..7dd1e96 100644 --- a/src/devices/cuda/cudadevicebatch.cpp +++ b/src/devices/cuda/cudadevicebatch.cpp @@ -255,7 +255,7 @@ namespace fastllm { "CatDirect error: inputs should use same device.\n"); AssertInFastLLM(input0s[0]->dims.size() == 0 || input0s[0]->dims.size() == input1s[0]->dims.size(), "Cat Error: input's shape's size should be same.\n"); - int dimsLen = input0s[1]->dims.size(); + int dimsLen = input1s[0]->dims.size(); axis = (axis % dimsLen + dimsLen) % dimsLen; for (int i = 0; i < dimsLen && i < input0s[0]->dims.size(); i++) { if (i != axis) {