Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check Table in angular_separation functions #191

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pyirf/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,5 @@ class WrongColumnUnit(IRFException):

def __init__(self, column, unit, expected):
super().__init__(
f'Unit {unit} of column "{column}"'
f' has incompatible unit "{unit}", expected {expected}'
f" required column {column}"
f'Column "{column}" has incompatible unit "{unit}", expected "{expected}".'
)
118 changes: 102 additions & 16 deletions pyirf/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
import astropy.units as u
from astropy.table import QTable
import numpy as np
import pytest
from astropy.table import QTable, Table

from pyirf.exceptions import MissingColumns, WrongColumnUnit


def test_is_scalar():
Expand Down Expand Up @@ -31,21 +33,105 @@ def test_cone_solid_angle():


def test_check_table():
from pyirf.exceptions import MissingColumns, WrongColumnUnit
from pyirf.utils import check_table

t = QTable({'bar': [0, 1, 2] * u.TeV})

with pytest.raises(MissingColumns):
check_table(t, required_columns=['foo'])

t = QTable({'bar': [0, 1, 2] * u.TeV})
with pytest.raises(WrongColumnUnit):
check_table(t, required_units={'bar': u.m})

t = QTable({'bar': [0, 1, 2] * u.m})
with pytest.raises(MissingColumns):
check_table(t, required_units={'foo': u.m})
# works with Table as well as QTable
check_table(Table({"foo": [1]}), required_columns=["foo"])
check_table(Table({"foo": [1] * u.m}), required_units={"foo": u.m})
check_table(QTable({"foo": [1] * u.m}), required_units={"foo": u.m})

t = Table({"bar": [0, 1, 2] * u.TeV})

with pytest.raises(
MissingColumns,
match="Table is missing required columns {'foo'}",
):
check_table(t, required_columns=["foo"])

t = Table({"bar": [0, 1, 2] * u.TeV})
with pytest.raises(
WrongColumnUnit,
match='Column "bar" has incompatible unit "TeV", expected "m".',
):
check_table(t, required_units={"bar": u.m})

t = Table({"bar": [0, 1, 2] * u.m})
with pytest.raises(
MissingColumns,
match="Table is missing required columns foo",
):
check_table(t, required_units={"foo": u.m})

# m is convertible
check_table(t, required_units={'bar': u.cm})
check_table(t, required_units={"bar": u.cm})

t = Table({"bar": [0, 1, 2]})
with pytest.raises(
WrongColumnUnit,
match='Column "bar" has incompatible unit "None", expected "cm".',
):
check_table(t, required_units={"bar": u.cm})


def test_calculate_theta():
from pyirf.utils import calculate_theta

true_az = true_alt = u.Quantity([1.0], u.deg)
t = QTable({"reco_alt": true_alt, "reco_az": true_az})

assert u.isclose(
calculate_theta(
events=t,
assumed_source_az=true_az,
assumed_source_alt=true_alt,
),
0.0 * u.deg,
)

t = Table({"reco_alt": [1.0], "reco_az": [1.0]})
with pytest.raises(
WrongColumnUnit,
match='Column "reco_az" has incompatible unit "None", expected "deg".',
):
calculate_theta(t, true_az, true_alt)


def test_calculate_source_fov_offset():
from pyirf.utils import calculate_source_fov_offset

a = u.Quantity([1.0], u.deg)
t = QTable(
{
"pointing_az": a,
"pointing_alt": a,
"true_az": a,
"true_alt": a,
}
)

assert u.isclose(calculate_source_fov_offset(t), 0.0 * u.deg)


def test_check_histograms():
from pyirf.binning import create_histogram_table
from pyirf.utils import check_histograms

events1 = QTable(
{
"reco_energy": [1, 1, 10, 100, 100, 100] * u.TeV,
}
)
events2 = QTable(
{
"reco_energy": [100, 100, 100] * u.TeV,
}
)
bins = [0.5, 5, 50, 500] * u.TeV

hist1 = create_histogram_table(events1, bins)
hist2 = create_histogram_table(events2, bins)
check_histograms(hist1, hist2)

hist3 = create_histogram_table(events1, [0, 10] * u.TeV)
with pytest.raises(ValueError):
check_histograms(hist1, hist3)
19 changes: 15 additions & 4 deletions pyirf/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import numpy as np
import astropy.units as u
import numpy as np
from astropy.coordinates.angle_utilities import angular_separation
from astropy.table import QTable

from .exceptions import MissingColumns, WrongColumnUnit


__all__ = [
"is_scalar",
"calculate_theta",
Expand Down Expand Up @@ -50,6 +50,7 @@ def calculate_theta(events, assumed_source_az, assumed_source_alt):
Angular separation between the assumed and reconstructed positions
in the sky.
"""
check_table(events, required_units={"reco_az": u.deg, "reco_alt": u.deg})
theta = angular_separation(
assumed_source_az,
assumed_source_alt,
Expand Down Expand Up @@ -78,6 +79,15 @@ def calculate_source_fov_offset(events, prefix="true"):
Angular separation between the true and pointing positions
in the sky.
"""
check_table(
events,
required_units={
f"{prefix}_az": u.deg,
f"{prefix}_alt": u.deg,
"pointing_az": u.deg,
"pointing_alt": u.deg,
},
)
theta = angular_separation(
events[f"{prefix}_az"],
events[f"{prefix}_alt"],
Expand Down Expand Up @@ -133,7 +143,7 @@ def check_table(table, required_columns=None, required_units=None):

Parameters
----------
table: astropy.table.QTable
table: astropy.table.Table
Table to check
required_columns: iterable[str]
Column names that are required to be present
Expand All @@ -147,6 +157,7 @@ def check_table(table, required_columns=None, required_units=None):
as keys in ``required_units are`` not present in the table.
WrongColumnUnit: if any column has the wrong unit
"""
table = QTable(table)
if required_columns is not None:
missing = set(required_columns) - set(table.colnames)
if missing:
Expand All @@ -158,5 +169,5 @@ def check_table(table, required_columns=None, required_units=None):
raise MissingColumns(col)

unit = table[col].unit
if not expected.is_equivalent(unit):
if not unit or not expected.is_equivalent(unit):
HealthyPear marked this conversation as resolved.
Show resolved Hide resolved
nbiederbeck marked this conversation as resolved.
Show resolved Hide resolved
raise WrongColumnUnit(col, unit, expected)