From 0aba0ae53574da62bda71a372a42107b2bf5ad6b Mon Sep 17 00:00:00 2001 From: alphasentaurii Date: Mon, 2 May 2022 20:14:11 -0400 Subject: [PATCH] black formatting --- conftest.py | 38 +++++--- docs/source/conf.py | 147 +++++++++++++++-------------- setup.py | 8 +- spacekit/preprocessor/encode.py | 82 ++++++++-------- spacekit/skopes/hst/svm/predict.py | 20 +++- tests/preprocessor/test_encode.py | 82 +++++++++++----- 6 files changed, 226 insertions(+), 151 deletions(-) diff --git a/conftest.py b/conftest.py index dd45ffa..6e1e5f1 100644 --- a/conftest.py +++ b/conftest.py @@ -5,6 +5,7 @@ from spacekit.analyzer.explore import HstCalPlots, HstSvmPlots from spacekit.analyzer.scan import SvmScanner, CalScanner, import_dataset from spacekit.extractor.load import load_datasets + # try: # from pytest_astropy_header.display import (PYTEST_HEADER_MODULES, # TESTED_VERSIONS) @@ -50,7 +51,9 @@ def __init__(self, env): "cal": os.path.join(f"tests/data/{env}/data.zip"), }[env] - self.kwargs = {"svm": dict(index_col="index"), "cal": dict(index_col="ipst")}[env] + self.kwargs = {"svm": dict(index_col="index"), "cal": dict(index_col="ipst")}[ + env + ] self.decoder = { "svm": {"det": {0: "hrc", 1: "ir", 2: "sbc", 3: "uvis", 4: "wfc"}}, @@ -68,17 +71,30 @@ def __init__(self, env): }[env] self.norm_cols = { - "svm": ["numexp", "rms_ra", "rms_dec", "nmatches", "point", "segment", "gaia"], + "svm": [ + "numexp", + "rms_ra", + "rms_dec", + "nmatches", + "point", + "segment", + "gaia", + ], "cal": ["n_files", "total_mb"], }[env] - self.rename_cols = { - "svm": "_scl", - "cal": ["x_files", "x_size"] - }[env] + self.rename_cols = {"svm": "_scl", "cal": ["x_files", "x_size"]}[env] self.enc_cols = { "svm": ["det", "wcs", "cat"], - "cal": ["drizcorr", "pctecorr", "crsplit", "subarray", "detector", "dtype", "instr"] + "cal": [ + "drizcorr", + "pctecorr", + "crsplit", + "subarray", + "detector", + "dtype", + "instr", + ], }[env] self.tx_file = { @@ -112,8 +128,8 @@ def res_data_path(cfg, tmp_path_factory): data_path = os.path.join(basepath, dname) return data_path - -@fixture(scope='session') + +@fixture(scope="session") def df_ncols(cfg): fname = cfg.labeled X_cols = cfg.norm_cols + cfg.enc_cols @@ -135,9 +151,7 @@ def scanner(cfg, res_data_path): @fixture(scope="session") def explorer(cfg, res_data_path): fname = res_data_path - df = import_dataset( - filename=fname, kwargs=cfg.kwargs, decoder=cfg.decoder - ) + df = import_dataset(filename=fname, kwargs=cfg.kwargs, decoder=cfg.decoder) if cfg.env == "svm": hst = HstSvmPlots(df) elif cfg.env == "cal": diff --git a/docs/source/conf.py b/docs/source/conf.py index 46bb5e1..2f3d470 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,28 +12,31 @@ # import os import sys -#import sphinx + +# import sphinx import datetime import os from configparser import ConfigParser -#import stsci_rtd_theme + +# import stsci_rtd_theme # -- Project information ----------------------------------------------------- # General information about the project -project = u'spacekit' -author = u'Ru Kein' +project = "spacekit" +author = "Ru Kein" year = datetime.datetime.now().year -copyright = f'{year}, {author}' +copyright = f"{year}, {author}" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # build documents. # The full version, including alpha/beta/rc tags. -release = '0.2.7' -#release = get_distribution(project).version +release = "0.2.7" +# release = get_distribution(project).version # The short X.Y version. -version = '.'.join(release.split('.')[:2]) +version = ".".join(release.split(".")[:2]) + def setup(app): app.add_css_file("stsci.css") @@ -41,17 +44,17 @@ def setup(app): # -- General configuration ------------------------------------------------ conf = ConfigParser() -conf.read([os.path.join(os.path.dirname(__file__), '..', 'setup.cfg')]) +conf.read([os.path.join(os.path.dirname(__file__), "..", "setup.cfg")]) # setup_cfg = dict(conf.items('metadata')) # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # sys.path.insert(0, os.path.abspath('.')) -sys.path.insert(0, os.path.abspath('../')) -sys.path.insert(0, os.path.abspath('../../')) +sys.path.insert(0, os.path.abspath("../")) +sys.path.insert(0, os.path.abspath("../../")) -#on_rtd = os.environ.get('READTHEDOCS', None) == 'True' +# on_rtd = os.environ.get('READTHEDOCS', None) == 'True' # Configuration for intersphinx: refer to the Python standard library. # Uncomment if you cross-ref to API doc from other packages. @@ -66,46 +69,46 @@ def setup(app): # (None, 'http://data.astropy.org/intersphinx/matplotlib.inv')), # noqa # 'astropy': ('https://docs.astropy.org/en/stable/', None)} intersphinx_mapping = { - 'python': ('http://docs.python.org/3/', None), - 'numpy': ('http://docs.scipy.org/doc/numpy/', None), - 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None), - 'matplotlib': ('http://matplotlib.org/', None), - 'astropy': ('http://docs.astropy.org/en/stable/', None), + "python": ("http://docs.python.org/3/", None), + "numpy": ("http://docs.scipy.org/doc/numpy/", None), + "scipy": ("http://docs.scipy.org/doc/scipy/reference/", None), + "matplotlib": ("http://matplotlib.org/", None), + "astropy": ("http://docs.astropy.org/en/stable/", None), } # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.inheritance_diagram', - 'sphinx.ext.viewcode', - 'sphinx.ext.napoleon', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'numpydoc', - 'sphinx_automodapi.automodapi', - 'sphinx_automodapi.automodsumm', - 'sphinx_automodapi.autodoc_enhancements', - 'sphinx_automodapi.smart_resolver', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.inheritance_diagram", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "numpydoc", + "sphinx_automodapi.automodapi", + "sphinx_automodapi.automodsumm", + "sphinx_automodapi.autodoc_enhancements", + "sphinx_automodapi.smart_resolver", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. # source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -119,11 +122,11 @@ def setup(app): # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. -default_role = 'obj' +default_role = "obj" # Don't show summaries of the members in each class along with the # class' docstring @@ -131,7 +134,7 @@ def setup(app): autosummary_generate = True -automodapi_toctreedirnm = 'api' +automodapi_toctreedirnm = "api" # Class documentation should contain *both* the class docstring and # the __init__ docstring @@ -141,12 +144,12 @@ def setup(app): graphviz_output_format = "svg" graphviz_dot_args = [ - '-Nfontsize=10', - '-Nfontname=Helvetica Neue, Helvetica, Arial, sans-serif', - '-Efontsize=10', - '-Efontname=Helvetica Neue, Helvetica, Arial, sans-serif', - '-Gfontsize=10', - '-Gfontname=Helvetica Neue, Helvetica, Arial, sans-serif' + "-Nfontsize=10", + "-Nfontname=Helvetica Neue, Helvetica, Arial, sans-serif", + "-Efontsize=10", + "-Efontname=Helvetica Neue, Helvetica, Arial, sans-serif", + "-Gfontsize=10", + "-Gfontname=Helvetica Neue, Helvetica, Arial, sans-serif", ] # If true, '()' will be appended to :func: etc. cross-reference text. @@ -161,7 +164,7 @@ def setup(app): # show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] @@ -176,7 +179,7 @@ def setup(app): # a list of builtin themes. # html_theme = 'stsci_rtd_theme' # html_theme_path = [stsci_rtd_theme.get_html_theme_path()] -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". @@ -194,18 +197,18 @@ def setup(app): # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. # html_extra_path = [] -html_static_path = ['_static'] +html_static_path = ["_static"] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -html_last_updated_fmt = '%b %d, %Y' +html_last_updated_fmt = "%b %d, %Y" # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. # html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -html_sidebars = {'**': ['globaltoc.html', 'relations.html', 'searchbox.html']} +html_sidebars = {"**": ["globaltoc.html", "relations.html", "searchbox.html"]} # Additional templates that should be rendered to pages, maps page names to # template names. @@ -238,25 +241,24 @@ def setup(app): # html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = f'{project}doc' +htmlhelp_basename = f"{project}doc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). - 'papersize': 'letterpaper', + "papersize": "letterpaper", # The font size ('10pt', '11pt' or '12pt'). - 'pointsize': '14pt', + "pointsize": "14pt", # Additional stuff for the LaTeX preamble. - 'preamble': r'''\usepackage{enumitem} \setlistdepth{99}''' + "preamble": r"""\usepackage{enumitem} \setlistdepth{99}""", } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ('index', f'{project}.tex', f'{project} Documentation', - f'{project}', 'manual'), + ("index", f"{project}.tex", f"{project} Documentation", f"{project}", "manual"), ] # The name of an image file (relative to this directory) to place at the top of @@ -270,7 +272,7 @@ def setup(app): # latex_show_pagerefs = False # If true, show URL addresses after external links. -latex_show_urls = 'True' +latex_show_urls = "True" # Documents to append as an appendix to all manuals. # latex_appendices = [] @@ -282,10 +284,7 @@ def setup(app): # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', f'{project}', f'{project} Documentation', - [f'{project}'], 1) -] +man_pages = [("index", f"{project}", f"{project} Documentation", [f"{project}"], 1)] # If true, show URL addresses after external links. man_show_urls = True @@ -296,9 +295,15 @@ def setup(app): # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', f'{project}', f'{project} Documentation', - f'{author}', f'{project}', f'{project}', - 'Miscellaneous'), + ( + "index", + f"{project}", + f"{project} Documentation", + f"{author}", + f"{project}", + f"{project}", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. @@ -308,7 +313,7 @@ def setup(app): texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -texinfo_show_urls = 'inline' +texinfo_show_urls = "inline" # If true, do not generate a @detailmenu in the "Top" node's menu. # texinfo_no_detailmenu = False @@ -316,10 +321,10 @@ def setup(app): # -- Options for Epub output ---------------------------------------------- # Bibliographic Dublin Core info. -epub_title = f'{project}' -epub_author = f'{author}' -epub_publisher = f'{author}' -epub_copyright = f'{year} {author}' +epub_title = f"{project}" +epub_author = f"{author}" +epub_publisher = f"{author}" +epub_copyright = f"{year} {author}" # The basename for the epub file. It defaults to the project name. # epub_basename = u'wfc3tools' @@ -328,7 +333,7 @@ def setup(app): # optimized for small screen space, using the same theme for HTML and # epub output is usually not wise. This defaults to 'epub', a theme designed # to save visual space. -epub_theme = 'epub' +epub_theme = "epub" # The language of the text. It defaults to the language option # or en if the language is not set. @@ -359,7 +364,7 @@ def setup(app): # epub_post_files = [] # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # The depth of the table of contents in toc.ncx. # epub_tocdepth = 3 @@ -380,4 +385,4 @@ def setup(app): # epub_show_urls = 'inline' # If false, no index is generated. -# epub_use_index = True \ No newline at end of file +# epub_use_index = True diff --git a/setup.py b/setup.py index 2c1a799..420f5f3 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ """ -if 'test' in sys.argv: +if "test" in sys.argv: print(TEST_HELP) sys.exit(1) @@ -29,10 +29,10 @@ """ -if 'build_docs' in sys.argv or 'build_sphinx' in sys.argv: +if "build_docs" in sys.argv or "build_sphinx" in sys.argv: print(DOCS_HELP) sys.exit(1) -#setup(use_scm_version={'write_to': 'spacekit/_version.py'}) -setup(version='0.3.0') +# setup(use_scm_version={'write_to': 'spacekit/_version.py'}) +setup(version="0.3.0") diff --git a/spacekit/preprocessor/encode.py b/spacekit/preprocessor/encode.py index 594a86c..33f1b69 100644 --- a/spacekit/preprocessor/encode.py +++ b/spacekit/preprocessor/encode.py @@ -3,6 +3,7 @@ from tensorflow.keras.utils import to_categorical import numpy as np + def encode_target_data(y_train, y_test): """Label encodes target class training and test data for multi-classification models. @@ -47,17 +48,17 @@ def lambda_func(self, inverse=False): self.inverse_pairs() I = lambda i: self.invpairs[i] return [I(b) for b in self.transformed] - + def inverse_pairs(self): self.invpairs = {} for key, value in self.keypairs.items(): self.invpairs[value] = key return self.invpairs - + def warn_unknowns(self): unknowns = np.where([a not in self.classes_ for a in self.arr]) print(f"WARNING: Found unknown values:\n {self.arr[unknowns]}") - + def handle_unknowns(self): unknowns = np.where([a not in self.classes_ for a in self.arr]) add_encoding = max(list(self.keypairs.values())) + 1 @@ -73,7 +74,9 @@ def handle_unknowns(self): def fit(self, data, keypairs, axiscol=None, handle_unknowns=True): if isinstance(data, pd.DataFrame): if axiscol is None: - print("Error: Must indicate which column to fit if `data` is a `dataframe`.") + print( + "Error: Must indicate which column to fit if `data` is a `dataframe`." + ) return try: self.arr = np.asarray(data[axiscol], dtype=object) @@ -95,7 +98,7 @@ def fit(self, data, keypairs, axiscol=None, handle_unknowns=True): self.keypairs = keypairs self.classes_ = list(self.keypairs.keys()) if self.arr.any() not in self.classes_: - #if self.arr.any() not in self.classes_: + # if self.arr.any() not in self.classes_: self.warn_unknowns() if handle_unknowns is True: self.handle_unknowns() @@ -106,14 +109,14 @@ def fit(self, data, keypairs, axiscol=None, handle_unknowns=True): except Exception as e: print(e) return self - + def transform(self): if self.arr is None: print("Error - Must fit the data first.") return self.transformed = self.lambda_func() return self.transformed - + def inverse_transform(self): inverse_pairs = {} for key, value in self.keypairs.items(): @@ -121,7 +124,7 @@ def inverse_transform(self): # TODO handle unknowns/nans inversely self.inversed = self.lambda_func(inverse=True) return self.inversed - + def fit_transform(self, data, keypairs, axiscol=None): self.fit(data, keypairs, axiscol=axiscol) self.transform() @@ -130,17 +133,22 @@ def fit_transform(self, data, keypairs, axiscol=None): class SvmEncoder: """Categorical encoding class for HST Single Visit Mosiac regression test data inputs.""" - def __init__(self, data, fkeys=['category', 'detector', 'wcstype'], names=['cat','det','wcs']): + def __init__( + self, + data, + fkeys=["category", "detector", "wcstype"], + names=["cat", "det", "wcs"], + ): """Instantiates an SvmEncoder class object. Parameters ---------- data : dataframe input data containing features (columns) to be encoded - + fkeys: list categorical-type column names (str) to be encoded - + names: list new names to assign columns of the encoded versions of categorical data @@ -150,9 +158,12 @@ def __init__(self, data, fkeys=['category', 'detector', 'wcstype'], names=['cat' self.names = names self.df = self.categorical_data() self.make_keypairs() - + def __repr__(self): - return 'encodings: %s \n category_keys: %s \n detector_keys: %s \n wcs_keys: %s' % (self.encodings, self.category_keys, self.detector_keys, self.wcs_keys) + return ( + "encodings: %s \n category_keys: %s \n detector_keys: %s \n wcs_keys: %s" + % (self.encodings, self.category_keys, self.detector_keys, self.wcs_keys) + ) def categorical_data(self): """Makes a copy of input dataframe and extracts only the categorical features based on the column names in `fkeys`. @@ -163,10 +174,9 @@ def categorical_data(self): dataframe with only the categorical feature columns """ return self.data.copy()[self.fkeys] - + def make_keypairs(self): - """Instantiates key-pair dictionaries for each of the categorical features. - """ + """Instantiates key-pair dictionaries for each of the categorical features.""" self.encodings = dict(zip(self.fkeys, self.names)) self.category_keys = self.set_category_keys() self.detector_keys = self.set_detector_keys() @@ -192,7 +202,7 @@ def init_categories(self): "EXT-STAR": "S", "CLUSTER OF GALAXIES": "GC", "GALAXY": "G", - "None": "U" + "None": "U", } def set_category_keys(self): @@ -205,16 +215,16 @@ def set_category_keys(self): """ self.category_keys = { "C": 0, - "SS":1, - "I":2, - "U":3, - "SC":4, - "S":5, - "GC":6, - "G":7, + "SS": 1, + "I": 2, + "U": 3, + "SC": 4, + "S": 5, + "GC": 6, + "G": 7, } return self.category_keys - + def set_detector_keys(self): """Assigns a hardcoded integer to each 'detector' key in alphabetical and increasing value. @@ -223,15 +233,9 @@ def set_detector_keys(self): dict detector names and their associated integer encoding """ - self.detector_keys = { - "hrc": 0, - "ir": 1, - "sbc": 2, - "uvis": 3, - "wfc": 4 - } + self.detector_keys = {"hrc": 0, "ir": 1, "sbc": 2, "uvis": 3, "wfc": 4} return self.detector_keys - + def set_wcs_keys(self): """Assigns a hardcoded integer to each 'wcs' key in alphabetical and increasing value. @@ -244,10 +248,10 @@ def set_wcs_keys(self): "a posteriori": 0, "a priori": 1, "default a": 2, - "not aligned": 3 + "not aligned": 3, } return self.wcs_keys - + def svm_keypairs(self, column): keypairs = { "category": self.category_keys, @@ -256,7 +260,7 @@ def svm_keypairs(self, column): } return keypairs[column] - def encode_categories(self, cname='category', sep=';'): + def encode_categories(self, cname="category", sep=";"): """Transforms the raw string inputs from MAST target category naming conventions into an abbreviated form. For example, `CLUSTER OF GALAXIES;GRAVITATIONA` becomes `GC` for galaxy cluster; and `STELLAR CLUSTER;GLOBULAR CLUSTER` becomes `SC` for stellar cluster. This serves to group similar but differently named objects into a discrete set of 8 possible categorizations. The 8 categories will then be encoded into integer values in the final encoding step (machine learning inputs must be numeric). Returns @@ -274,7 +278,7 @@ def encode_categories(self, cname='category', sep=';'): self.df.drop("category", axis=1, inplace=True) self.df = self.df.join(df_cat, how="left") return self.df - + def rejoin_original(self): originals = list(self.encodings.keys()) self.df.drop(originals, axis=1, inplace=True) @@ -296,6 +300,8 @@ def encode_features(self): enc.fit_transform(self.df, keypairs, axiscol=col) self.df[name] = enc.transformed print(f"\n*** {col} --> {name} ***") - print(f"ORIGINAL:\n{self.df[col].value_counts()}\n\nENCODED:\n{self.df[name].value_counts()}\n") + print( + f"ORIGINAL:\n{self.df[col].value_counts()}\n\nENCODED:\n{self.df[name].value_counts()}\n" + ) self.rejoin_original() return self.df diff --git a/spacekit/skopes/hst/svm/predict.py b/spacekit/skopes/hst/svm/predict.py index 6de2bdc..1bd85d1 100644 --- a/spacekit/skopes/hst/svm/predict.py +++ b/spacekit/skopes/hst/svm/predict.py @@ -29,7 +29,8 @@ TF_CPP_MIN_LOG_LEVEL = 2 -DETECTOR_KEY = {"hrc":0, "ir":1, "sbc":2, "uvis":3, "wfc":4} +DETECTOR_KEY = {"hrc": 0, "ir": 1, "sbc": 2, "uvis": 3, "wfc": 4} + def load_mixed_inputs(data_file, img_path, tx=None, size=128, norm=0): """Load the regression test data and image input data, then stacks the arrays into a single combined input (list) for the ensemble model. @@ -114,7 +115,10 @@ def classification_report(df, output_path, group=None): print("Alignment Evaluation") print("0.0=aligned, 1.0=suspicious") cnt_pct = pd.concat( - [P.value_counts(), P.value_counts(normalize=True), ], + [ + P.value_counts(), + P.value_counts(normalize=True), + ], axis=1, keys=["cnt", "pct"], ) @@ -177,7 +181,13 @@ def classify_alignments(X, model, output_path=None, group=None): def predict_alignment( - data_file, img_path, model_path=None, output_path=None, size=128, norm=0, group=None, + data_file, + img_path, + model_path=None, + output_path=None, + size=128, + norm=0, + group=None, ): """Main calling function to load the data and model, generate predictions, and save results to disk. @@ -251,7 +261,7 @@ def predict_alignment( "--group", type=str, default=None, - help="Name for this group of data (to be included in classification report)" + help="Name for this group of data (to be included in classification report)", ) args = parser.parse_args() _ = predict_alignment( @@ -261,5 +271,5 @@ def predict_alignment( output_path=args.output_path, size=args.size, norm=args.normalization, - group=args.group + group=args.group, ) diff --git a/tests/preprocessor/test_encode.py b/tests/preprocessor/test_encode.py index 9728940..7e4d747 100644 --- a/tests/preprocessor/test_encode.py +++ b/tests/preprocessor/test_encode.py @@ -42,29 +42,31 @@ def test_svm_encoder(scraped_mast_file): assert enc.df.det[0] == 1 assert enc.df.wcs[0] == 0 + @mark.svm @mark.preprocessor @mark.encode def test_pair_encoder_unknown_val(scraped_mast_file): data = pd.read_csv(scraped_mast_file, index_col="index") - data.loc['hst_12286_38_wfc3_ir_total_ibl738', 'wcstype'] = 'NaN' - keypairs = {'a posteriori': 0, 'a priori': 1, 'default a': 2, 'not aligned': 3} + data.loc["hst_12286_38_wfc3_ir_total_ibl738", "wcstype"] = "NaN" + keypairs = {"a posteriori": 0, "a priori": 1, "default a": 2, "not aligned": 3} enc = PairEncoder() - enc.fit(data, keypairs, axiscol='wcstype', handle_unknowns=False) + enc.fit(data, keypairs, axiscol="wcstype", handle_unknowns=False) try: enc.transform() except KeyError: assert True + @mark.svm @mark.preprocessor @mark.encode def test_svm_encoder_handle_unknown(scraped_mast_file): data = pd.read_csv(scraped_mast_file, index_col="index") - data.loc['hst_12286_38_wfc3_ir_total_ibl738', 'wcstype'] = 'NaN' - keypairs = {'a posteriori': 0, 'a priori': 1, 'default a': 2, 'not aligned': 3} + data.loc["hst_12286_38_wfc3_ir_total_ibl738", "wcstype"] = "NaN" + keypairs = {"a posteriori": 0, "a priori": 1, "default a": 2, "not aligned": 3} enc = PairEncoder() - enc.fit(data, keypairs, axiscol='wcstype', handle_unknowns=True) + enc.fit(data, keypairs, axiscol="wcstype", handle_unknowns=True) enc.transform() assert enc.transformed[0] == 4 @@ -74,7 +76,7 @@ def test_svm_encoder_handle_unknown(scraped_mast_file): @mark.encode def test_pair_encoder_unspecified_column(scraped_mast_file): data = pd.read_csv(scraped_mast_file, index_col="index") - keypairs = {'a posteriori': 0, 'a priori': 1, 'default a': 2, 'not aligned': 3} + keypairs = {"a posteriori": 0, "a priori": 1, "default a": 2, "not aligned": 3} enc = PairEncoder() enc.fit(data, keypairs) assert enc.arr is None @@ -83,34 +85,77 @@ def test_pair_encoder_unspecified_column(scraped_mast_file): except AttributeError: assert True + @mark.svm @mark.preprocessor @mark.encode def test_pair_encoder_array_1d(): - keypairs = {'a posteriori': 0, 'a priori': 1, 'default a': 2, 'not aligned': 3} - arr1d = asarray(['a priori'], dtype=object) + keypairs = {"a posteriori": 0, "a priori": 1, "default a": 2, "not aligned": 3} + arr1d = asarray(["a priori"], dtype=object) enc = PairEncoder() enc.fit(arr1d, keypairs) enc.transform() assert enc.transformed[0] == 1 + @mark.svm @mark.preprocessor @mark.encode def test_pair_encoder_array_2d(): - keypairs = {'a posteriori': 0, 'a priori': 1, 'default a': 2, 'not aligned': 3} - arr2d = asarray([['ir', 'ibl738', 'ANY', 262.46, 52.32, 2, 'myfile.fits', 284, 185, 7, 5.87, 13.38, 5, 'default a','UNIDENTIFIED;PARALLEL FIELD']]) + keypairs = {"a posteriori": 0, "a priori": 1, "default a": 2, "not aligned": 3} + arr2d = asarray( + [ + [ + "ir", + "ibl738", + "ANY", + 262.46, + 52.32, + 2, + "myfile.fits", + 284, + 185, + 7, + 5.87, + 13.38, + 5, + "default a", + "UNIDENTIFIED;PARALLEL FIELD", + ] + ] + ) enc = PairEncoder() enc.fit(arr2d, keypairs, axiscol=13) enc.transform() assert enc.transformed[0] == 2 + @mark.svm @mark.preprocessor @mark.encode def test_pair_encoder_array_2d_unspecified_axis(): - keypairs = {'a posteriori': 0, 'a priori': 1, 'default a': 2, 'not aligned': 3} - arr2d = asarray([['ir', 'ibl738', 'ANY', 262.46, 52.32, 2, 'myfile.fits', 284, 185, 7, 5.87, 13.38, 5, 'default a','UNIDENTIFIED;PARALLEL FIELD']]) + keypairs = {"a posteriori": 0, "a priori": 1, "default a": 2, "not aligned": 3} + arr2d = asarray( + [ + [ + "ir", + "ibl738", + "ANY", + 262.46, + 52.32, + 2, + "myfile.fits", + 284, + 185, + 7, + 5.87, + 13.38, + 5, + "default a", + "UNIDENTIFIED;PARALLEL FIELD", + ] + ] + ) enc = PairEncoder() enc.fit(arr2d, keypairs) try: @@ -118,18 +163,13 @@ def test_pair_encoder_array_2d_unspecified_axis(): except AttributeError: assert True + @mark.svm @mark.preprocessor @mark.encode def test_pair_encoder_inverse_transform(scraped_mast_file): - data = asarray(['ir', 'ir', 'uvis', 'wfc'], dtype=object) - detector_keys = { - "hrc": 0, - "ir": 1, - "sbc": 2, - "uvis": 3, - "wfc": 4 - } + data = asarray(["ir", "ir", "uvis", "wfc"], dtype=object) + detector_keys = {"hrc": 0, "ir": 1, "sbc": 2, "uvis": 3, "wfc": 4} enc = PairEncoder() enc.fit(data, detector_keys) enc.transform()