Skip to content

Commit

Permalink
lint.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Nov 14, 2024
1 parent 25c43c9 commit 764a660
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
15 changes: 7 additions & 8 deletions python-package/xgboost/dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,13 @@ def get_address_from_user(
for k in dconfig:
if k not in valid_config:
raise ValueError(f"Unknown configuration: {k}")
else:
warnings.warn(
(
"Use `coll_cfg` instead of the Dask global configuration store"
f" for the XGBoost tracker configuration: {k}."
),
FutureWarning,
)
warnings.warn(
(
"Use `coll_cfg` instead of the Dask global configuration store"
f" for the XGBoost tracker configuration: {k}."
),
FutureWarning,
)
else:
dconfig = {}

Expand Down
8 changes: 6 additions & 2 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,9 +1043,13 @@ def _get_tracker_args(self) -> Tuple[bool, Dict[str, Any]]:
tracker_port = self.getOrDefault(self.tracker_port)

num_workers = self.getOrDefault(self.num_workers)
rabit_args.update(_get_rabit_args(tracker_host_ip, num_workers, tracker_port))
rabit_args.update(
_get_rabit_args(tracker_host_ip, num_workers, tracker_port)
)
else:
if self.isDefined(self.tracker_host_ip) or self.isDefined(self.tracker_port):
if self.isDefined(self.tracker_host_ip) or self.isDefined(
self.tracker_port
):
raise ValueError(
"You must enable launch_tracker_on_driver to use "
"tracker_host_ip and tracker_port"
Expand Down
4 changes: 3 additions & 1 deletion tests/test_distributed/test_with_spark/test_spark_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -1657,7 +1657,9 @@ def test_tracker(self):
classifier._get_tracker_args()

classifier = SparkXGBClassifier(
launch_tracker_on_driver=False, tracker_host_ip="127.0.0.1", tracker_port=58892
launch_tracker_on_driver=False,
tracker_host_ip="127.0.0.1",
tracker_port=58892,
)
with pytest.raises(
ValueError, match="You must enable launch_tracker_on_driver"
Expand Down

0 comments on commit 764a660

Please sign in to comment.