From 4736a6d406212f7c8866b3c65f6ad0498d09f1d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Mon, 23 Sep 2024 11:45:12 +0800 Subject: [PATCH] =?UTF-8?q?fix=20=E7=9B=B4=E6=8E=A5=E8=AF=BBbge-largh-zh?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/fastllm.cpp | 4 +++- src/model.cpp | 3 +++ src/models/bert.cpp | 12 +++++++++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/fastllm.cpp b/src/fastllm.cpp index edc699cb..01cff726 100644 --- a/src/fastllm.cpp +++ b/src/fastllm.cpp @@ -549,7 +549,9 @@ namespace fastllm { data.UpdateUnitSize(); data.Allocate(); if (dataType == oriDataType) { - memcpy(data.cpuData, oriData, data.GetBytes()); + if (oriData != nullptr) { + memcpy(data.cpuData, oriData, data.GetBytes()); + } } else if (oriDataType == DataType::BFLOAT16 && dataType == DataType::FLOAT16) { uint16_t *a = (uint16_t*)data.cpuData; diff --git a/src/model.cpp b/src/model.cpp index ec8f53f4..4b060590 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -287,6 +287,9 @@ namespace fastllm { if (dstType != DataType::FLOAT32) { ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n"); } + } else if (this->dtype == "I64") { + printf("skip I64 tensor %s\n", this->tensorName.c_str()); + return; } else { ErrorInFastLLM("SafeTensorItem.CreateBuffer: unsupport src dtype " + this->dtype + "\n"); } diff --git a/src/models/bert.cpp b/src/models/bert.cpp index 45186331..02875a8e 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -212,7 +212,17 @@ namespace fastllm { void BertModel::WarmUp() { printf("Warmup...\n"); - EmbeddingSentence({"1"}, true); + int batch = 1, len = 1; + std::vector ids = std::vector (batch * len, 0.0f); + std::vector seqLens = std::vector (batch, 0.0f); + std::vector token_type_ids = std::vector (batch * len, 0.0f); + std::vector attention_mask = std::vector (batch * len, -1e10f); + std::vector position_ids = std::vector (batch * len, 0.0f); + fastllm::Data inputIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids); + fastllm::Data attentionMask = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask); + fastllm::Data tokenTypeIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids); + fastllm::Data positionIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids); + ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, true); printf("finish.\n"); }