diff --git a/src/models/minicpm3.cpp b/src/models/minicpm3.cpp
index dbd2337..d3e0940 100644
--- a/src/models/minicpm3.cpp
+++ b/src/models/minicpm3.cpp
@@ -277,8 +277,198 @@ namespace fastllm {
                                                  const std::vector <GenerationConfig> &generationConfigs,
                                                  const LastTokensManager &lastTokens,
                                                  std::vector <std::vector <float>*> *retLogits) {
-        AssertInFastLLM(false, "MiniCpm3Model::ForwardBatch todo");
-        return {0};
+
+        int v_head_dim = this->hidden_size / this->num_attention_heads;
+        Data hiddenStates;
+        Data attenInput;
+        Data qa, qa_norm, qb, batch_q_nope, batch_q_rope;
+        Data kva, compressed_kv, batch_k_rope, kv_norm, kvb;
+        Data batch_k_nope, k_rope_expand, batch_value_states, query_states, key_states;
+        Data attenWeights, curAttenOutput, attenLastOutput;
+        Data w1, w2, w3;
+
+        Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
+        Mul(hiddenStates, embed_scale, hiddenStates);
+        for (int i = 0; i < block_cnt; i++) {
+            ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
+            RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".input_layernorm.weight"],
+                    1e-5, attenInput);
+            std::string qaWeightName = "model.layers." + std::to_string(i) + ".self_attn.q_a_proj.weight";
+            std::string qbWeightName = "model.layers." + std::to_string(i) + ".self_attn.q_b_proj.weight";
+            std::string kvaWeightName = "model.layers." + std::to_string(i) + ".self_attn.kv_a_proj_with_mqa.weight";
+            std::string kvbWeightName = "model.layers." + std::to_string(i) + ".self_attn.kv_b_proj.weight";
+            std::string oWeightName = "model.layers." + std::to_string(i) + ".self_attn.o_proj.weight";
+
+            // 1.1 Get q, k, v
+            int bsz = attenInput.dims[0], b_seqlen = attenInput.dims[1];
+            Linear(attenInput, weight[qaWeightName], Data(), qa);
+            RMSNorm(qa, this->weight["model.layers." + std::to_string(i) + ".self_attn.q_a_layernorm.weight"],
+                    1e-5, qa_norm);
+            Linear(qa_norm, weight[qbWeightName], Data(), qb);
+            qb.Reshape({bsz, b_seqlen, num_attention_heads, -1});
+            PermuteSelf(qb, {0, 2, 1, 3});
+            // Split q into its no-RoPE part and its RoPE part
+            Split(qb, -1, 0, this->qk_nope_head_dim, batch_q_nope);
+            Split(qb, -1, this->qk_nope_head_dim, this->qk_nope_head_dim + this->qk_rope_head_dim, batch_q_rope);
+
+            Linear(attenInput, weight[kvaWeightName], Data(), kva);
+            Split(kva, -1, 0, this->kv_lora_rank, compressed_kv);
+            Split(kva, -1, this->kv_lora_rank, this->kv_lora_rank + this->qk_rope_head_dim, batch_k_rope);
+            batch_k_rope.Reshape({bsz, 1, b_seqlen, this->qk_rope_head_dim});
+            RMSNorm(compressed_kv, this->weight["model.layers." + std::to_string(i) + ".self_attn.kv_a_layernorm.weight"],
+                    1e-5, kv_norm);
+            Linear(kv_norm, weight[kvbWeightName], Data(), kvb);
+            kvb.Reshape({bsz, b_seqlen, num_attention_heads, qk_nope_head_dim + v_head_dim});
+            PermuteSelf(kvb, {0, 2, 1, 3});
+            Split(kvb, -1, 0, qk_nope_head_dim, batch_k_nope);
+            Split(kvb, -1, qk_nope_head_dim, qk_nope_head_dim + v_head_dim, batch_value_states);
+
+            Data attenOutput = Data(DataType::FLOAT32);
+            int total = 0;
+            std::vector <Data> curQNs, curQRs, curKNs, curKRs, curVs;
+            curQNs.resize(batch);
+            curQRs.resize(batch);
+            curKNs.resize(batch);
+            curKRs.resize(batch);
+            curVs.resize(batch);
+            for (int b = 0; b < batch; b++) {
+                Split(batch_q_nope, 2, total, total + seqLens[b], curQNs[b]);
+                Split(batch_q_rope, 2, total, total + seqLens[b], curQRs[b]);
+                Split(batch_k_nope, 2, total, total + seqLens[b], curKNs[b]);
+                Split(batch_k_rope, 2, total, total + seqLens[b], curKRs[b]);
+                Split(batch_value_states, 2, total, total + seqLens[b], curVs[b]);
+                total += seqLens[b];
+            }
+
+            for (int b = 0; b < batch; b++) {
+                int seqlen = seqLens[b];
+                auto &q_nope = curQNs[b], &q_rope = curQRs[b];
+                auto &k_nope = curKNs[b], &k_rope = curKRs[b], &value_states = curVs[b];
+
+                PermuteSelf(q_rope, {0, 2, 1, 3});
+                PermuteSelf(k_rope, {0, 2, 1, 3});
+                fastllm::LlamaRotatePosition2D(q_rope, *positionIds[b], sinData, cosData, rotary_dim);
+                fastllm::LlamaRotatePosition2D(k_rope, *positionIds[b], sinData, cosData, rotary_dim);
+                PermuteSelf(q_rope, {0, 2, 1, 3});
+                PermuteSelf(k_rope, {0, 2, 1, 3});
+                Cat(q_nope, q_rope, -1, query_states);
+
+                // The RoPE part of k is shared by all heads, so replicate it before concatenating with k_nope
+                k_rope.Reshape({bsz, seqlen * qk_rope_head_dim});
+                k_rope_expand.ToDevice(DataDevice::CUDA);
+                k_rope_expand.CopyFrom(k_rope);
+                k_rope_expand.Expansion({bsz, num_attention_heads * seqlen * qk_rope_head_dim});
+                for (int i = 1; i < num_attention_heads; i++)
+                    CatDirect(k_rope_expand, k_rope, 1);
+                k_rope_expand.expansionDims.clear();
+                k_rope_expand.Reshape({bsz, num_attention_heads, seqlen, qk_rope_head_dim});
+                Cat(k_nope, k_rope_expand, -1, key_states);
+
+                Data &pastKey = *pastKeyValues[b * block_cnt + i].first, &pastValue = *pastKeyValues[b * block_cnt + i].second;
+                if (GetKVCacheInCPU()) {
+                    pastKey.lockInCPU = true;
+                    pastValue.lockInCPU = true;
+                } else {
+                    pastKey.ToDevice(DataDevice::CUDA);
+                    pastValue.ToDevice(DataDevice::CUDA);
+                }
+                key_states.Reshape({bsz * num_attention_heads, seqlen, -1});
+                value_states.Reshape({bsz * num_attention_heads, seqlen, -1});
+
+                // Grow the KV cache in unitLen-sized chunks whenever the reserved space runs out
+                int key_unitLen = 96;
+                #ifdef USE_CUDA
+                key_unitLen = 192;
+                #endif
+                while ((pastKey.dims.size() == 0 && (pastKey.expansionDims.size() == 0 || key_states.dims[1] > pastKey.expansionDims[1]))
+                       || (pastKey.dims.size() > 0 && pastKey.dims[1] + key_states.dims[1] > pastKey.expansionDims[1])) {
+                    std::vector <int> newDims;
+                    if (pastKey.Count(0) == 0 || pastKey.dims.size() == 0) {
+                        newDims = std::vector <int> {key_states.dims[0], ((key_states.dims[1] - 1) / key_unitLen + 1) * key_unitLen, key_states.dims[2]};
+                    } else {
+                        newDims = pastKey.dims;
+                        newDims[1] += ((key_states.dims[1] - 1) / key_unitLen + 1) * key_unitLen;
+                    }
+                    pastKey.Expansion(newDims);
+                }
+                int value_unitLen = 64;
+                #ifdef USE_CUDA
+                value_unitLen = 128;
+                #endif
+                while ((pastValue.dims.size() == 0 && (pastValue.expansionDims.size() == 0 || value_states.dims[1] > pastValue.expansionDims[1]))
+                       || (pastValue.dims.size() > 0 && pastValue.dims[1] + value_states.dims[1] > pastValue.expansionDims[1])) {
+                    std::vector <int> newDims;
+                    if (pastValue.Count(0) == 0 || pastValue.dims.size() == 0) {
+                        newDims = std::vector <int> {value_states.dims[0], ((value_states.dims[1] - 1) / value_unitLen + 1) * value_unitLen, value_states.dims[2]};
+                    } else {
+                        newDims = pastValue.dims;
+                        newDims[1] += ((value_states.dims[1] - 1) / value_unitLen + 1) * value_unitLen;
+                    }
+                    pastValue.Expansion(newDims);
+                }
+                CatDirect(pastKey, key_states, 1);
+                CatDirect(pastValue, value_states, 1);
+
+                // 1.2 Attention
+                // 1.2.0 q * k^T
+                query_states.Reshape({bsz * num_attention_heads, seqlen, -1});
+                MatMulTransB(query_states, pastKey, attenWeights, 1.0 / sqrt(v_head_dim));
+                attenWeights.Reshape({1, attenWeights.dims[0], attenWeights.dims[1], attenWeights.dims[2]});
+                if (seqlen > 1) {
+                    // Causal mask over the freshly appended prompt tokens
+                    int promptLen = pastKey.dims[1];
+                    std::vector <float> vmask = std::vector <float> (seqlen * promptLen, 0);
+                    for (int i = 0; i < seqlen; i++)
+                        for (int j = i + 1; j < seqlen; j++)
+                            vmask[i * promptLen + (promptLen - seqlen + j)] = 1;
+                    AttentionMask(attenWeights, Data(DataType::FLOAT32, {seqlen, promptLen}, vmask), -10000);
+                }
+                Softmax(attenWeights, attenWeights, -1);
+                MatMul(attenWeights, pastValue, curAttenOutput);
+                curAttenOutput.Reshape({bsz, num_attention_heads, seqlen, v_head_dim});
+                PermuteSelf(curAttenOutput, {0, 2, 1, 3});
+                curAttenOutput.Reshape({bsz, seqlen, num_attention_heads * v_head_dim});
+                if (attenOutput.dims.size() == 0) {
+                    std::vector <int> dims = curAttenOutput.dims;
+                    dims[1] = total;
+                    attenOutput.Expansion(dims);
+                }
+                CatDirect(attenOutput, curAttenOutput, 1);
+            }
+            Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput);
+            AddTo(hiddenStates, attenLastOutput, this->attention_scale);
+
+            // 2. mlp
+            RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput);
+            Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
+            Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3);
+            Silu(w1, w1);
+            MulTo(w1, w3);
+            Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
+
+            AddTo(hiddenStates, w2, this->attention_scale);
+        }
+
+        Data logits, curLogit;
+        RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-5, hiddenStates);
+        Mul(hiddenStates, this->rms_scale, hiddenStates);
+        Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
+        // Sample the next token for each sequence from the logits of its last position
+        std::vector <int> lastRet;
+        int total = 0;
+        for (int b = 0; b < batch; b++) {
+            Split(logits, 1, total + seqLens[b] - 1, total + seqLens[b], curLogit);
+            if (generationConfigs[b].output_logits && retLogits != nullptr && (*retLogits)[b] != nullptr) {
+                curLogit.ToDevice(DataDevice::CPU);
+                (*retLogits)[b]->resize(curLogit.Count(0));
+                memcpy((float*)(*retLogits)[b]->data(), (float*)curLogit.cpuData, curLogit.GetBytes());
+            }
+            if (generationConfigs[b].IsSimpleGreedy()) {
+                Data topk;
+                TopK(curLogit, topk, 1);
+                topk.ToDevice(DataDevice::CPU);
+                lastRet.push_back((int) (((float *) topk.cpuData)[0] + 1e-3));
+            } else {
+                lastRet.push_back(LLMSampling(curLogit, 0, generationConfigs[b], lastTokens.units[b]));
+            }
+            total += seqLens[b];
+        }
+        return lastRet;
     }
 }