diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 85c73462..08998739 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -45,11 +45,10 @@ jobs: - name: Build wheels run: python -m cibuildwheel --output-dir wheelhouse env: - # TODO(jakevdp): re-add 313t & free-threading support CIBW_ARCHS_LINUX: auto aarch64 CIBW_ARCHS_MACOS: universal2 - CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313-* # cp313t-* - # CIBW_FREE_THREADED_SUPPORT: True + CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313* + CIBW_FREE_THREADED_SUPPORT: True CIBW_PRERELEASE_PYTHONS: True CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*" CIBW_TEST_REQUIRES: absl-py pytest pytest-xdist diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index 94dc5f53..7eb313bb 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ b/ml_dtypes/tests/custom_float_test.py @@ -27,7 +27,7 @@ from absl.testing import absltest from absl.testing import parameterized import ml_dtypes -# from multi_thread_utils import multi_threaded +from multi_thread_utils import multi_threaded import numpy as np bfloat16 = ml_dtypes.bfloat16 @@ -221,12 +221,16 @@ def dtype_is_signed(dtype): } -# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release. # pylint: disable=g-complex-comprehension -# @multi_threaded( -# num_workers=3, -# skip_tests=["testDiv", "testRoundTripNumpyTypes", "testRoundTripToNumpy"], -# ) +@multi_threaded( + num_workers=3, + skip_tests=[ + "testDiv", + "testPickleable", + "testRoundTripNumpyTypes", + "testRoundTripToNumpy", + ], +) @parameterized.named_parameters( ( {"testcase_name": "_" + dtype.__name__, "float_type": dtype} @@ -661,21 +665,25 @@ def testDtypeFromString(self, float_type): ] -# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release. # pylint: disable=g-complex-comprehension -# @multi_threaded( -# num_workers=3, -# skip_tests=[ -# "testBinaryUfunc", -# "testConformNumpyComplex", -# "testFloordivCornerCases", -# "testDivmodCornerCases", -# "testSpacing", -# "testUnaryUfunc", -# "testCasts", -# "testLdexp", -# ], -# ) +@multi_threaded( + num_workers=3, + skip_tests=[ + "testBinaryPredicateUfunc", + "testBinaryUfunc", + "testCasts", + "testConformNumpyComplex", + "testDivmod", + "testDivmodCornerCases", + "testFloordivCornerCases", + "testFrexp", + "testLdexp", + "testModf", + "testPredicateUfunc", + "testSpacing", + "testUnaryUfunc", + ], +) @parameterized.named_parameters( ( {"testcase_name": "_" + dtype.__name__, "float_type": dtype} diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index 7807fd75..d15311c1 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -15,7 +15,7 @@ from absl.testing import absltest from absl.testing import parameterized import ml_dtypes -# from multi_thread_utils import multi_threaded +from multi_thread_utils import multi_threaded import numpy as np ALL_DTYPES = [ @@ -55,8 +55,7 @@ } -# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release. -# @multi_threaded(num_workers=3) +@multi_threaded(num_workers=3) class FinfoTest(parameterized.TestCase): def assertNanEqual(self, x, y): diff --git a/ml_dtypes/tests/iinfo_test.py b/ml_dtypes/tests/iinfo_test.py index 4f15446f..8936c523 100644 --- a/ml_dtypes/tests/iinfo_test.py +++ b/ml_dtypes/tests/iinfo_test.py @@ -15,12 +15,11 @@ from absl.testing import absltest from absl.testing import parameterized import ml_dtypes -# from multi_thread_utils import multi_threaded +from multi_thread_utils import multi_threaded import numpy as np -# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release. -# @multi_threaded(num_workers=3) +@multi_threaded(num_workers=3) class IinfoTest(parameterized.TestCase): def testIinfoInt2(self): diff --git a/ml_dtypes/tests/intn_test.py b/ml_dtypes/tests/intn_test.py index 47e9688b..86ab5a81 100644 --- a/ml_dtypes/tests/intn_test.py +++ b/ml_dtypes/tests/intn_test.py @@ -23,7 +23,7 @@ from absl.testing import absltest from absl.testing import parameterized import ml_dtypes -# from multi_thread_utils import multi_threaded +from multi_thread_utils import multi_threaded import numpy as np int2 = ml_dtypes.int2 @@ -48,9 +48,8 @@ def ignore_warning(**kw): yield -# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release. # Tests for the Python scalar type -# @multi_threaded(num_workers=3) +@multi_threaded(num_workers=3) class ScalarTest(parameterized.TestCase): @parameterized.product(scalar_type=INTN_TYPES) @@ -247,9 +246,8 @@ def testCanCast(self, a, b): ) -# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release. # Tests for the Python scalar type -# @multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"]) +@multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"]) class ArrayTest(parameterized.TestCase): @parameterized.product(scalar_type=INTN_TYPES) diff --git a/ml_dtypes/tests/metadata_test.py b/ml_dtypes/tests/metadata_test.py index 99abe919..81da5367 100644 --- a/ml_dtypes/tests/metadata_test.py +++ b/ml_dtypes/tests/metadata_test.py @@ -16,11 +16,10 @@ from absl.testing import absltest import ml_dtypes -# from multi_thread_utils import multi_threaded +from multi_thread_utils import multi_threaded -# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release. -# @multi_threaded(num_workers=3) +@multi_threaded(num_workers=3) class CustomFloatTest(absltest.TestCase): def test_version_matches_package_metadata(self):