Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabled free-threading multithread tests and added more skips #203

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,10 @@ jobs:
- name: Build wheels
run: python -m cibuildwheel --output-dir wheelhouse
env:
# TODO(jakevdp): re-add 313t & free-threading support
CIBW_ARCHS_LINUX: auto aarch64
CIBW_ARCHS_MACOS: universal2
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313-* # cp313t-*
# CIBW_FREE_THREADED_SUPPORT: True
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313*
CIBW_FREE_THREADED_SUPPORT: True
CIBW_PRERELEASE_PYTHONS: True
CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*"
CIBW_TEST_REQUIRES: absl-py pytest pytest-xdist
Expand Down
48 changes: 28 additions & 20 deletions ml_dtypes/tests/custom_float_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
# from multi_thread_utils import multi_threaded
from multi_thread_utils import multi_threaded
import numpy as np

bfloat16 = ml_dtypes.bfloat16
Expand Down Expand Up @@ -221,12 +221,16 @@ def dtype_is_signed(dtype):
}


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# pylint: disable=g-complex-comprehension
# @multi_threaded(
# num_workers=3,
# skip_tests=["testDiv", "testRoundTripNumpyTypes", "testRoundTripToNumpy"],
# )
@multi_threaded(
num_workers=3,
skip_tests=[
"testDiv",
"testPickleable",
"testRoundTripNumpyTypes",
"testRoundTripToNumpy",
],
)
@parameterized.named_parameters(
(
{"testcase_name": "_" + dtype.__name__, "float_type": dtype}
Expand Down Expand Up @@ -661,21 +665,25 @@ def testDtypeFromString(self, float_type):
]


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# pylint: disable=g-complex-comprehension
# @multi_threaded(
# num_workers=3,
# skip_tests=[
# "testBinaryUfunc",
# "testConformNumpyComplex",
# "testFloordivCornerCases",
# "testDivmodCornerCases",
# "testSpacing",
# "testUnaryUfunc",
# "testCasts",
# "testLdexp",
# ],
# )
@multi_threaded(
num_workers=3,
skip_tests=[
"testBinaryPredicateUfunc",
"testBinaryUfunc",
"testCasts",
"testConformNumpyComplex",
"testDivmod",
"testDivmodCornerCases",
"testFloordivCornerCases",
"testFrexp",
"testLdexp",
"testModf",
"testPredicateUfunc",
"testSpacing",
"testUnaryUfunc",
],
)
@parameterized.named_parameters(
(
{"testcase_name": "_" + dtype.__name__, "float_type": dtype}
Expand Down
5 changes: 2 additions & 3 deletions ml_dtypes/tests/finfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
# from multi_thread_utils import multi_threaded
from multi_thread_utils import multi_threaded
import numpy as np

ALL_DTYPES = [
Expand Down Expand Up @@ -55,8 +55,7 @@
}


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# @multi_threaded(num_workers=3)
@multi_threaded(num_workers=3)
class FinfoTest(parameterized.TestCase):

def assertNanEqual(self, x, y):
Expand Down
5 changes: 2 additions & 3 deletions ml_dtypes/tests/iinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
# from multi_thread_utils import multi_threaded
from multi_thread_utils import multi_threaded
import numpy as np


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# @multi_threaded(num_workers=3)
@multi_threaded(num_workers=3)
class IinfoTest(parameterized.TestCase):

def testIinfoInt2(self):
Expand Down
8 changes: 3 additions & 5 deletions ml_dtypes/tests/intn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
# from multi_thread_utils import multi_threaded
from multi_thread_utils import multi_threaded
import numpy as np

int2 = ml_dtypes.int2
Expand All @@ -48,9 +48,8 @@ def ignore_warning(**kw):
yield


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# Tests for the Python scalar type
# @multi_threaded(num_workers=3)
@multi_threaded(num_workers=3)
class ScalarTest(parameterized.TestCase):

@parameterized.product(scalar_type=INTN_TYPES)
Expand Down Expand Up @@ -247,9 +246,8 @@ def testCanCast(self, a, b):
)


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# Tests for the Python scalar type
# @multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"])
@multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"])
class ArrayTest(parameterized.TestCase):

@parameterized.product(scalar_type=INTN_TYPES)
Expand Down
5 changes: 2 additions & 3 deletions ml_dtypes/tests/metadata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

from absl.testing import absltest
import ml_dtypes
# from multi_thread_utils import multi_threaded
from multi_thread_utils import multi_threaded


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# @multi_threaded(num_workers=3)
@multi_threaded(num_workers=3)
class CustomFloatTest(absltest.TestCase):

def test_version_matches_package_metadata(self):
Expand Down
Loading