Skip to content

Commit

Permalink
parse versions.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 30, 2024
1 parent a2b61a7 commit 8e04e41
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 13 deletions.
32 changes: 22 additions & 10 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,27 @@ def _register_log_callback(lib: ctypes.CDLL) -> None:
raise XGBoostError(lib.XGBGetLastError())


def _parse_version(ver: str) -> Tuple[Tuple[int, int, int], str]:
"""Avoid dependency on packaging (PEP 440)."""
# 2.0.0-dev, 2.0.0, 2.0.0.post1, or 2.0.0rc1
if ver.find("post") != -1:
major, minor, patch = ver.split(".")[:-1]
postfix = ver.split(".")[-1]
elif "-dev" in ver:
major, minor, patch = ver.split("-")[0].split(".")
postfix = "dev"
else:
major, minor, patch = ver.split(".")
rc = patch.find("rc")
if rc != -1:
postfix = patch[rc:]
patch = patch[:rc]
else:
postfix = ""

return (int(major), int(minor), int(patch)), postfix


def _load_lib() -> ctypes.CDLL:
"""Load xgboost Library."""
lib_paths = find_lib_path()
Expand Down Expand Up @@ -237,17 +258,8 @@ def _load_lib() -> ctypes.CDLL:
)
_register_log_callback(lib)

def parse(ver: str) -> Tuple[int, int, int]:
"""Avoid dependency on packaging (PEP 440)."""
# 2.0.0-dev, 2.0.0, or 2.0.0rc1
major, minor, patch = ver.split("-")[0].split(".")
rc = patch.find("rc")
if rc != -1:
patch = patch[:rc]
return int(major), int(minor), int(patch)

libver = _lib_version(lib)
pyver = parse(_py_version())
pyver, _ = _parse_version(_py_version())

# verify that we are loading the correct binary.
if pyver != libver:
Expand Down
14 changes: 11 additions & 3 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
XGBoostError,
_deprecate_positional_args,
_parse_eval_str,
_parse_version,
_py_version,
)
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_alike, _is_pandas_df
from .training import train
Expand Down Expand Up @@ -801,10 +803,16 @@ def _doc_link_module(self) -> str:

@property
def _doc_link_template(self) -> str:
from .core import _py_version

ver = _py_version()
rel = "latest" if ver.endswith("-dev") else "stable"
(major, minor, patch), post = _parse_version(ver)

if post == "dev":
rel = "latest"
else:
# RTD tracks the release branch, we don't have different branch patch
# release.
rel = f"{major}.{minor}.0"

module = self.__class__.__module__
# All sklearn estimators are forwarded to the top level module in both source
# code and sphinx api doc.
Expand Down
12 changes: 12 additions & 0 deletions tests/python/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import xgboost as xgb
from xgboost import testing as tm
from xgboost.core import _parse_version

dpath = "demo/data/"
rng = np.random.RandomState(1994)
Expand Down Expand Up @@ -315,3 +316,14 @@ def test_Booster_init_invalid_path(self):
"""An invalid model_file path should raise XGBoostError."""
with pytest.raises(xgb.core.XGBoostError):
xgb.Booster(model_file=Path("invalidpath"))


def test_parse_ver() -> None:
(major, minor, patch), post = _parse_version("2.1.0")
assert post == ""
(major, minor, patch), post = _parse_version("2.1.0-dev")
assert post == "dev"
(major, minor, patch), post = _parse_version("2.1.0rc1")
assert post == "rc1"
(major, minor, patch), post = _parse_version("2.1.0.post1")
assert post == "post1"

0 comments on commit 8e04e41

Please sign in to comment.