From fa5bafe617793ed340303cf0ebded6ac03cab39f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 21 Nov 2024 11:56:41 +0000 Subject: [PATCH] debug windows --- .github/workflows/tests.yml | 6 +++--- Makefile | 2 +- tests/test_callbacks.py | 2 ++ 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5d58726d30..38fc80a2ad 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -56,7 +56,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install ".[dev]" + python -m pip install ".[mergekit]" - name: Test with pytest run: | make test @@ -88,7 +88,7 @@ jobs: python -m pip install -U git+https://github.com/huggingface/accelerate.git python -m pip install -U git+https://github.com/huggingface/datasets.git python -m pip install -U git+https://github.com/huggingface/transformers.git - python -m pip install ".[dev]" + python -m pip install ".[mergekit]" - name: Test with pytest run: | make test @@ -149,7 +149,7 @@ jobs: python -m pip install accelerate==0.34.0 python -m pip install datasets==2.21.0 python -m pip install transformers==4.46.0 - python -m pip install ".[dev]" + python -m pip install ".[mergekit]" - name: Test with pytest run: | make test diff --git a/Makefile b/Makefile index 704cacbff2..25c23fb83e 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ dev: ln -s `pwd`/examples/scripts/ `pwd`/trl/commands test: - python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' ./tests/ + python -m pytest tests/test_callbacks.py::MergeModelCallbackTester::test_callback precommit: pre-commit run --all-files diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 0f03cb6a91..df5e15f80a 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -299,6 +299,8 @@ def test_callback(self): trainer.train() last_checkpoint = get_last_checkpoint(tmp_dir) merged_path = os.path.join(last_checkpoint, "merged") + import warnings + warnings.warn(f"merged_path: {merged_path}") self.assertTrue(os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint.") def test_every_checkpoint(self):