Skip to content

Commit

Permalink
Merge pull request #3 from HaoZeke/validateMethods
Browse files Browse the repository at this point in the history
MAINT,TST: Validate methods
  • Loading branch information
HaoZeke authored Aug 8, 2023
2 parents 1c39492 + 81bb33b commit f1b3cba
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 0 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Basic consistency tests
on:
push:
branches:
- main
pull_request:
branches:
- main

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
buildmamba:
runs-on: ${{ matrix.config.os }}
name: ${{ matrix.config.os }}
strategy:
fail-fast: false
matrix:
config:
- {os: macOS-latest}
- {os: ubuntu-latest}
steps:
- uses: actions/checkout@v3
with:
submodules: "recursive"
fetch-depth: 0
- name: Install Conda environment
uses: mamba-org/setup-micromamba@v1
with:
environment-file: environment.yml
cache-environment: true
- name: Build and Test
shell: bash -l {0}
run: |
fpm build
pytest -vvv
3 changes: 3 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ dependencies:
# Python, Sympy
- ipython
- sympy
- numpy
- jupytext
- papermill
- ipykernel
# More testing
- pytest
70 changes: 70 additions & 0 deletions pytests/test_gjp_rec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import subprocess

import numpy as np
import pytest


def extract_values(output_str):
lines = output_str.strip().split("\n")
roots = []
weights = []
for line in lines:
_, root, _, weight = line.split()
roots.append(float(root))
weights.append(float(weight))
return np.array(roots), np.array(weights)


def run_fortran(n, alpha, beta):
result = subprocess.run(
[
"fpm",
"run",
"gjp_quad_rec",
"--",
str(n),
"{:.1f}".format(alpha),
"{:.1f}".format(beta),
],
stdout=subprocess.PIPE,
)
return extract_values(result.stdout.decode())


def run_python(n, alpha, beta):
result = subprocess.run(
[
"python",
"scripts/gen_analytic_vals.py",
"--npts",
str(n),
"--alpha",
str(alpha),
"--beta",
str(beta),
],
stdout=subprocess.PIPE,
)
return extract_values(result.stdout.decode())


@pytest.mark.parametrize(
"n, alpha, beta",
[
(3, 1, 5),
(5, 2, 3),
pytest.param(
10,
0.0,
30.0,
marks=pytest.mark.xfail(reason="High beta values diverge"),
id="highbeta_fail",
),
],
)
def test_gjp_quad_rec(n, alpha, beta):
fortran_roots, fortran_weights = run_fortran(n, alpha, beta)
python_roots, python_weights = run_python(n, alpha, beta)

assert np.allclose(fortran_roots, python_roots, atol=1e-14)
assert np.allclose(fortran_weights, python_weights, atol=1e-14)

0 comments on commit f1b3cba

Please sign in to comment.