Skip to content

Commit

Permalink
user SerializeToBuffer for internal serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes committed Jan 10, 2024
1 parent 692e5a5 commit c161999
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -737,15 +737,13 @@ void XGBAltrepSetPointer(SEXP R_altrepped_obj, BoosterHandle handle) {
R_RegisterCFinalizerEx(R_ptr, _BoosterFinalizer, TRUE);
}

const char *ubj_json_format_str = "{\"format\": \"ubj\"}";

SEXP XGBAltrepSerializer_R(SEXP R_altrepped_obj) {
R_API_BEGIN();
BoosterHandle handle = R_ExternalPtrAddr(R_altrep_data1(R_altrepped_obj));
char const *serialized_bytes;
bst_ulong serialized_length;
CHECK_CALL(XGBoosterSaveModelToBuffer(
handle, ubj_json_format_str, &serialized_length, &serialized_bytes));
CHECK_CALL(XGBoosterSerializeToBuffer(
handle, &serialized_length, &serialized_bytes));
SEXP R_state = Rf_protect(Rf_allocVector(RAWSXP, serialized_length));
if (serialized_length != 0) {
std::memcpy(RAW(R_state), serialized_bytes, serialized_length);
Expand All @@ -761,9 +759,9 @@ SEXP XGBAltrepDeserializer_R(SEXP unused, SEXP R_state) {
R_API_BEGIN();
BoosterHandle handle = nullptr;
CHECK_CALL(XGBoosterCreate(nullptr, 0, &handle));
int res_code = XGBoosterLoadModelFromBuffer(handle,
RAW(R_state),
Rf_xlength(R_state));
int res_code = XGBoosterUnserializeFromBuffer(handle,
RAW(R_state),
Rf_xlength(R_state));
if (res_code != 0) {
XGBoosterFree(handle);
}
Expand Down Expand Up @@ -794,14 +792,14 @@ SEXP XGBAltrepDuplicate_R(SEXP R_altrepped_obj, Rboolean deep) {
SEXP out = Rf_protect(XGBMakeEmptyAltrep());
char const *serialized_bytes;
bst_ulong serialized_length;
CHECK_CALL(XGBoosterSaveModelToBuffer(
CHECK_CALL(XGBoosterSerializeToBuffer(
R_ExternalPtrAddr(R_altrep_data1(R_altrepped_obj)),
ubj_json_format_str, &serialized_length, &serialized_bytes));
&serialized_length, &serialized_bytes));
BoosterHandle new_handle = nullptr;
CHECK_CALL(XGBoosterCreate(nullptr, 0, &new_handle));
int res_code = XGBoosterLoadModelFromBuffer(new_handle,
serialized_bytes,
serialized_length);
int res_code = XGBoosterUnserializeFromBuffer(new_handle,
serialized_bytes,
serialized_length);
if (res_code != 0) {
XGBoosterFree(new_handle);
}
Expand Down

0 comments on commit c161999

Please sign in to comment.