Skip to content

Commit

Permalink
Add tests for Argchecking
Browse files Browse the repository at this point in the history
  • Loading branch information
maxmynter committed Jan 22, 2024
1 parent f7dc214 commit d40fe07
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@
def testfolder() -> str:
"""A test directory for benchmark collection."""
return str(HERE / "testproject")


@pytest.fixture(scope='session')
def typecheckfolder() -> str:
return str(HERE / "typecheck_tests")
22 changes: 22 additions & 0 deletions tests/test_argcheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
import pytest
from nnbench import runner


def test_argcheck(typecheckfolder: str) -> None:
benchmarks = os.path.join(typecheckfolder, "benchmarks.py")
r = runner.AbstractBenchmarkRunner()
with pytest.raises(TypeError):
r.run(benchmarks, params={"x": 1, "y": "1"})
with pytest.raises(ValueError):
r.run(benchmarks, params={"x": 1})
r.run(benchmarks,
params={"x": 1, "y": 1})


def test_raises_erro_on_duplicate_params(typecheckfolder: str) -> None:
benchmarks = os.path.join(typecheckfolder, "duplicate_benchmarks.py")

with pytest.raises(TypeError):
r = runner.AbstractBenchmarkRunner()
r.run(benchmarks, params={"x": 1, "y": 1})
16 changes: 16 additions & 0 deletions tests/typecheck_tests/benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import nnbench


@nnbench.benchmark
def double(x: int) -> int:
return x*2


@nnbench.benchmark
def triple(y: int) -> int:
return y * 3


@nnbench.benchmark
def prod(x: int, y: int) -> int:
return x * y
11 changes: 11 additions & 0 deletions tests/typecheck_tests/duplicate_benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import nnbench


@nnbench.benchmark
def double(x: int) -> int:
return x*2


@nnbench.benchmark
def triple_str(x: str) -> str:
return x * 3

0 comments on commit d40fe07

Please sign in to comment.