Skip to content

Commit

Permalink
Merge pull request #20 from alphasentaurii/svm-pred-report
Browse files Browse the repository at this point in the history
fix/svm-pred-encoding
  • Loading branch information
alphasentaurii authored May 3, 2022
2 parents 5032843 + 0aba0ae commit 473743a
Show file tree
Hide file tree
Showing 13 changed files with 622 additions and 139 deletions.
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ $ cd spacekit
$ pip install -e .
```

*Testing*

See `tox.ini` for a list of test suite markers.

```bash
# run all tests
$ pytest

# some tests, like the `scan` module rely on the test `env` option
$ pytest --env svm -m scan
$ pytest --env cal -m scan
```


### Pre-Trained Neural Nets

Expand Down
40 changes: 27 additions & 13 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from spacekit.analyzer.explore import HstCalPlots, HstSvmPlots
from spacekit.analyzer.scan import SvmScanner, CalScanner, import_dataset
from spacekit.extractor.load import load_datasets

# try:
# from pytest_astropy_header.display import (PYTEST_HEADER_MODULES,
# TESTED_VERSIONS)
Expand Down Expand Up @@ -50,7 +51,9 @@ def __init__(self, env):
"cal": os.path.join(f"tests/data/{env}/data.zip"),
}[env]

self.kwargs = {"svm": dict(index_col="index"), "cal": dict(index_col="ipst")}[env]
self.kwargs = {"svm": dict(index_col="index"), "cal": dict(index_col="ipst")}[
env
]

self.decoder = {
"svm": {"det": {0: "hrc", 1: "ir", 2: "sbc", 3: "uvis", 4: "wfc"}},
Expand All @@ -68,17 +71,30 @@ def __init__(self, env):
}[env]

self.norm_cols = {
"svm": ["numexp", "rms_ra", "rms_dec", "nmatches", "point", "segment", "gaia"],
"svm": [
"numexp",
"rms_ra",
"rms_dec",
"nmatches",
"point",
"segment",
"gaia",
],
"cal": ["n_files", "total_mb"],
}[env]
self.rename_cols = {
"svm": "_scl",
"cal": ["x_files", "x_size"]
}[env]
self.rename_cols = {"svm": "_scl", "cal": ["x_files", "x_size"]}[env]

self.enc_cols = {
"svm": ["det", "wcs", "cat"],
"cal": ["drizcorr", "pctecorr", "crsplit", "subarray", "detector", "dtype", "instr"]
"cal": [
"drizcorr",
"pctecorr",
"crsplit",
"subarray",
"detector",
"dtype",
"instr",
],
}[env]

self.tx_file = {
Expand Down Expand Up @@ -112,8 +128,8 @@ def res_data_path(cfg, tmp_path_factory):
data_path = os.path.join(basepath, dname)
return data_path

@fixture(scope='session')

@fixture(scope="session")
def df_ncols(cfg):
fname = cfg.labeled
X_cols = cfg.norm_cols + cfg.enc_cols
Expand All @@ -135,9 +151,7 @@ def scanner(cfg, res_data_path):
@fixture(scope="session")
def explorer(cfg, res_data_path):
fname = res_data_path
df = import_dataset(
filename=fname, kwargs=cfg.kwargs, decoder=cfg.decoder
)
df = import_dataset(filename=fname, kwargs=cfg.kwargs, decoder=cfg.decoder)
if cfg.env == "svm":
hst = HstSvmPlots(df)
elif cfg.env == "cal":
Expand Down Expand Up @@ -244,4 +258,4 @@ def scraped_mast_file():
# CAL
@fixture(scope="function")
def cal_labeled_dataset():
return "tests/data/cal/train/training.csv"
return "tests/data/cal/train/training.csv"
Loading

0 comments on commit 473743a

Please sign in to comment.