Skip to content

Commit

Permalink
fix bug in FIL
Browse files Browse the repository at this point in the history
  • Loading branch information
getumen committed Aug 6, 2024
1 parent 95b2ae8 commit 26a2b97
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 17 deletions.
14 changes: 3 additions & 11 deletions src/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,13 @@ namespace
char const *filename,
TreeliteModelHandle *model_handle)
{
std::string json_config = "{\"allow_unknown_field\": True}";
std::string json_config = "{}";
switch (model_type)
{
case ModelType::XGBoost:
return TreeliteLoadXGBoostModel(filename, json_config.c_str(), model_handle);
return TreeliteLoadXGBoostModelLegacyBinary(filename, json_config.c_str(), model_handle);
case ModelType::XGBoostJSON: {
std::ifstream file(filename); // Replace with your file name
if (!file.is_open()) {
return -1;
}
std::string content((std::istreambuf_iterator<char>(file)),
std::istreambuf_iterator<char>());
file.close();

return TreeliteLoadXGBoostModelFromString(content.c_str(), content.length(), json_config.c_str(), model_handle);
return TreeliteLoadXGBoostModel(filename, json_config.c_str(), model_handle);
}
case ModelType::LightGBM:
return TreeliteLoadLightGBMModel(filename, json_config.c_str(), model_handle);
Expand Down
4 changes: 2 additions & 2 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ add_executable(
cuml_test
memory_resource_test.cpp
clustering_test.cpp
# fil_test.cpp
# linear_regression_test.cpp
fil_test.cpp
linear_regression_test.cpp
)

target_compile_options(cuml_test PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda --expt-relaxed-constexpr>)
Expand Down
9 changes: 5 additions & 4 deletions tests/fil_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
#include "cuml4c/memory_resource.h"
#include "cuml4c/fil.h"

TEST(FILTest, TestTreelite)

TEST(FILTest, TestTreeliteJSON)
{
std::string json_config = "{\"allow_unknown_field\": True}";
std::string json_config = "{}";

TreeliteModelHandle handle;
auto res = TreeliteLoadXGBoostModel("testdata/xgboost.json", json_config.c_str(), &handle);
Expand All @@ -25,10 +26,10 @@ TEST(FILTest, TestFIL)
CreateDeviceResourceHandle(&device_resource_handle);

DeviceMemoryResource mr;
UseArenaMemoryResource(&mr, 64 * 1024);
UseArenaMemoryResource(&mr, 1024 * 1024);

FILModelHandle handle;
auto res = FILLoadModel(device_resource_handle, 0, "testdata/xgboost.model", 0, true, 0.5, 0, 0, 1, 0, &handle);
auto res = FILLoadModel(device_resource_handle, 1, "testdata/xgboost.json", 0, true, 0.5, 0, 0, 1, 0, &handle);
EXPECT_EQ(res, 0);

std::vector<float> feature;
Expand Down

0 comments on commit 26a2b97

Please sign in to comment.