diff --git a/Plot_Table3.ipynb b/Plot_Table3.ipynb new file mode 100644 index 0000000..f9c9fef --- /dev/null +++ b/Plot_Table3.ipynb @@ -0,0 +1,220 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import pickle\n", + "\n", + "import matplotlib as mpl\n", + "# print(mpl.rcParams.items)\n", + "mpl.use('Agg')\n", + "mpl.rcParams['text.usetex'] = False\n", + "mpl.rcParams['mathtext.rm'] = 'serif'\n", + "mpl.rcParams['font.family'] = 'serif'\n", + "mpl.rcParams['font.serif'] = ['Times New Roman']\n", + "# mpl.rcParams['font.family'] = ['Times New Roman']\n", + "mpl.rcParams['axes.titlesize'] = 25\n", + "mpl.rcParams['axes.labelsize'] = 20\n", + "mpl.rcParams['xtick.labelsize'] = 15\n", + "mpl.rcParams['ytick.labelsize'] = 15\n", + "mpl.rcParams['savefig.dpi'] = 250\n", + "mpl.rcParams['figure.dpi'] = 250\n", + "mpl.rcParams['savefig.format'] = 'pdf'\n", + "mpl.rcParams['savefig.bbox'] = 'tight'\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):\n", + " new_cmap = mpl.colors.LinearSegmentedColormap.from_list(\n", + " 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval),\n", + " cmap(np.linspace(minval, maxval, n)))\n", + " return new_cmap\n", + "\n", + "cmap = plt.get_cmap('hot_r')\n", + "fave_cmap = truncate_colormap(cmap, 0.35, 1.0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metric_dictionary = {'TBDT':{'FoM': 1, 'LogLoss': 1, 'Brier': 1},\n", + " 'TKNN':{'FoM': 7, 'LogLoss': 6, 'Brier': 7},\n", + " 'TNB':{'FoM': 8, 'LogLoss': 9, 'Brier': 8},\n", + " 'TNN':{'FoM': 5, 'LogLoss': 3, 'Brier': 3},\n", + " 'TSVM':{'FoM': 3, 'LogLoss': 2, 'Brier': 2},\n", + " 'WBDT':{'FoM': 2, 'LogLoss': 5, 'Brier': 4},\n", + " 'WKNN':{'FoM': 9, 'LogLoss': 8, 'Brier': 9},\n", + " 'WNB':{'FoM': 10, 'LogLoss': 10, 'Brier': 10},\n", + " 'WNN':{'FoM': 6, 'LogLoss': 7, 'Brier': 6},\n", + " 'WSVM':{'FoM': 4, 'LogLoss': 4, 'Brier': 5},\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metric_dictionary['TBDT']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "symbols = {'TBDT':'o',\n", + " 'TKNN':'d',\n", + " 'TNB':'s',\n", + " 'TNN':'*',\n", + " 'TSVM':'^',\n", + " 'WBDT':'o',\n", + " 'WKNN':'d',\n", + " 'WNB':'s',\n", + " 'WNN':'*',\n", + " 'WSVM':'^',\n", + " }\n", + "\n", + "colors = {'TBDT':fave_cmap(0.05),\n", + " 'TKNN':fave_cmap(0.3),\n", + " 'TNB':fave_cmap(0.55),\n", + " 'TNN':fave_cmap(0.8),\n", + " 'TSVM':fave_cmap(1.0),\n", + " 'WBDT':fave_cmap(0.05),\n", + " 'WKNN':fave_cmap(0.3),\n", + " 'WNB':fave_cmap(0.55),\n", + " 'WNN':fave_cmap(0.75),\n", + " 'WSVM':fave_cmap(1.0),\n", + " }\n", + "\n", + "\n", + "plt.figure()\n", + "for key, value in metric_dictionary.items():\n", + " val = []\n", + " for k, v in value.items():\n", + " val.append(v)\n", + " if 'W' in key:\n", + " plt.plot(val, label=key, marker=symbols[key], ls='--', color=colors[key])\n", + " else:\n", + " plt.plot(val, label=key, marker=symbols[key], color=colors[key])\n", + "\n", + "plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), prop={'size': 12})\n", + "plt.xticks([0, 1, 2], ['FoM', 'LogLoss', 
'Brier'])\n", + "plt.yticks(np.arange(1, 11))\n", + "plt.ylabel('Rank')\n", + "\n", + "#plt.savefig('Tables3_option1.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "colors = {'TBDT':fave_cmap(0.05),\n", + " 'TKNN':fave_cmap(0.2375),\n", + " 'TNB':fave_cmap(0.54),\n", + " 'TNN':fave_cmap(0.712499999),\n", + " 'TSVM':fave_cmap(1.0),\n", + " 'WBDT':fave_cmap(0.05),\n", + " 'WKNN':fave_cmap(0.2375),\n", + " 'WNB':fave_cmap(0.54),\n", + " 'WNN':fave_cmap(0.712499999),\n", + " 'WSVM':fave_cmap(1.0),\n", + " }\n", + "\n", + "plt.figure()\n", + "for key, value in metric_dictionary.items():\n", + " val = []\n", + " for k, v in value.items():\n", + " val.append(v)\n", + " if 'W' in key:\n", + " plt.plot(val, label=key, marker=symbols[key], ls='--', color=colors[key], lw=2, ms=7, alpha=0.3)\n", + " else:\n", + " plt.plot(val, label=key, marker=symbols[key], color=colors[key], lw=2, ms=7)\n", + "\n", + "plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), prop={'size': 12})\n", + "plt.xticks([0, 1, 2], ['FoM', 'LogLoss', 'Brier'])\n", + "plt.yticks(np.arange(1, 11))\n", + "plt.ylabel('Rank')\n", + "plt.gca().invert_yaxis()\n", + "\n", + "#plt.savefig('Tables3_option4.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure()\n", + "\n", + "fom = []\n", + "ll = []\n", + "brier = []\n", + "\n", + "for key, value in metric_dictionary.items():\n", + " fom.append(value['FoM'])\n", + " ll.append(value['LogLoss'])\n", + " brier.append(value['Brier'])\n", + "\n", + "plt.plot(fom, label='FoM', marker='o')\n", + "plt.plot(ll, label='LogLoss', marker='D', alpha = 0.5)\n", + "plt.plot(brier, label='Brier', marker='s', alpha=0.23)\n", + "\n", + "plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), prop={'size': 12})\n", + "plt.xticks(np.arange(0, 10), list(metric_dictionary.keys()), rotation=45)\n", + "plt.ylabel('Rank')\n", + "plt.savefig('Tables3_option2.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/paper/authors.csv b/paper/authors.csv index 44f07c7..d381072 100644 --- a/paper/authors.csv +++ b/paper/authors.csv @@ -1,11 +1,12 @@ Lastname,Firstname,Authorname,AuthorType,Affiliation,Contribution,Email +Malz,Alex,A.I.~Malz,Contact,"German Centre of Cosmological Lensing, Ruhr-Universitaet Bochum, Universitaetsstra{\ss}e 150, 44801 Bochum, Germany","conceptualization, data curation, formal analysis, investigation, methodology, project administration, software, supervision, validation, visualization, writing - editing, writing - original draft",aimalz@nyu.edu Malz,Alex,A.I.~Malz,Contact,"Center for Cosmology and Particle Physics, New York University, 726 Broadway, New York, NY 10004, USA","conceptualization, data curation, formal analysis, investigation, methodology, project administration, software, supervision, validation, visualization, writing - editing, writing - original draft",aimalz@nyu.edu Malz,Alex,A.I.~Malz,Contact,"Department of Physics, New 
York University, 726 Broadway, New York, NY 10004, USA","conceptualization, data curation, formal analysis, investigation, methodology, project administration, software, supervision, validation, visualization, writing - editing, writing - original draft",aimalz@nyu.edu Hlo\v{z}ek,Ren\'ee,R.~Hlo\v{z}ek,Contributor,"Department of Astronomy and Astrophysics, University of Toronto, 50 St. George St., Toronto, ON M5S 3H4, Canada","data curation, formal analysis, funding acquisition, investigation, project administration, software, supervision, validation, visualization, writing - editing, writing - original draft",hlozek@dunlap.utoronto.ca Hlo\v{z}ek,Ren\'ee,R.~Hlo\v{z}ek,Contributor,"Dunlap Institute for Astronomy and Astrophysics, University of Toronto, 50 St. George St., Toronto, ON M5S 3H4, Canada","data curation, formal analysis, funding acquisition, investigation, project administration, software, supervision, validation, visualization, writing - editing, writing - original draft",hlozek@dunlap.utoronto.ca Allam,Tarek,T.~Allam Jr,Contributor,"Mullard Space Science Laboratory, Department of Space and Climate Physics, University College London, Holmbury Hill Rd, Dorking RH5 6NT, UK","investigation, software, validation, writing - original draft",[email] Bahmanyar,Anita,A.~Bahmanyar,Contributor,"Dunlap Institute for Astronomy and Astrophysics, University of Toronto, 50 St. George St., Toronto, ON M5S 3H4, Canada","formal analysis, investigation, methodology, software, writing - editing, writing - original draft",[email] -Biswas,Rahul,R.~Biswas,Contributor,"The Oskar Klein Centre for Cosmoparticle Physics, Stockholm University, AlbaNova, Stockholm, SE-106 91, Sweden","conceptualization, methodology, software, writing - original draft",[email] +Biswas,Rahul,R.~Biswas,Contributor,"The Oskar Klein Centre for Cosmoparticle Physics, Stockholm University, AlbaNova, Stockholm, SE-106 91, Sweden","conceptualization, methodology, software, supervision, writing - editing, writing - original draft",[email] Dai,Mi,M.~Dai,Contributor,"Rutgers, the State University of New Jersey, 136 Frelinghuysen Road, Piscataway, NJ 08854 USA","writing - editing",[email] Galbany,Llu\'is,L.~Galbany,Contributor,"University of Pittsburgh, 300 Allen Hall, 3941 O'Hara St, Pittsburgh, PA 15260","writing - editing",[email] Ishida,Emille,E.E.O.~Ishida,Contributor,"Universit\'e Clermont Auvergne, CNRS/IN2P3, LPC, F-63000 Clermont-Ferrand, France","conceptualization, project administration, supervision, writing - editing",[email] @@ -25,5 +26,5 @@ Narayan,Gautham,G.~Narayan,Contributor,"Space Telescope Science Institute, 3700 Peiris,Hiranya,H.~Peiris,Contributor,"The Oskar Klein Centre for Cosmoparticle Physics, Stockholm University, AlbaNova, Stockholm, SE-106 91, Sweden","conceptualization, funding acquisition, supervision",[email] Peiris,Hiranya,H.~Peiris,Contributor,"Department of Physics and Astronomy, University College London, Gower Street, London, WC1E 6BT, UK","conceptualization, funding acquisition, supervision",[email] Peters,Christina~M.,C.M.~Peters,Contributor,"Dunlap Institute for Astronomy and Astrophysics, University of Toronto, 50 St. 
George St., Toronto, ON M5S 3H4, Canada","writing - editing",[email] -Ponder,Kara,K.~Ponder,Contributor,"Berkeley Center for Cosmological Physics, Campbell Hall 341, University of California Berkeley, Berkeley, CA 94720, USA","writing - editing",[email] +Ponder,Kara,K.~Ponder,Contributor,"Berkeley Center for Cosmological Physics, Campbell Hall 341, University of California Berkeley, Berkeley, CA 94720, USA","visualization, writing - editing",[email] Setzer,Christian,C.N.~Setzer,Contributor,"The Oskar Klein Centre for Cosmoparticle Physics, Stockholm University, AlbaNova, Stockholm, SE-106 91, Sweden","conceptualization, software",christian.setzer@fysik.su.se diff --git a/paper/fig/Tables3_option4.png b/paper/fig/Tables3_option4.png new file mode 100644 index 0000000..4120ab6 Binary files /dev/null and b/paper/fig/Tables3_option4.png differ diff --git a/paper/fig/all_sim_cm.png b/paper/fig/all_sim_cm.png index 4d055d3..331c646 100644 Binary files a/paper/fig/all_sim_cm.png and b/paper/fig/all_sim_cm.png differ diff --git a/paper/fig/all_snphotcc_cm.png b/paper/fig/all_snphotcc_cm.png index c60b00b..9d86051 100644 Binary files a/paper/fig/all_snphotcc_cm.png and b/paper/fig/all_snphotcc_cm.png differ diff --git a/paper/fig/combined.png b/paper/fig/combined.png index ab860fb..ce502a4 100644 Binary files a/paper/fig/combined.png and b/paper/fig/combined.png differ diff --git a/paper/fig/examples.png b/paper/fig/examples.png index 5510a8e..b70aaf4 100644 Binary files a/paper/fig/examples.png and b/paper/fig/examples.png differ diff --git a/paper/hacks.ipynb b/paper/hacks.ipynb index 2ec1b65..ed992f6 100644 --- a/paper/hacks.ipynb +++ b/paper/hacks.ipynb @@ -622,7 +622,7 @@ "weight_vecs = np.array([[i] + [(1. - i) / (M_classes - 1.)] * (M_classes - 1) for i in possible_weights])\n", "which_weight_schemes = {str(i): weight_vecs[i] for i in range(len(possible_weights))}\n", "\n", - "alt_mega_test = load_collector('fig/test'+str(M_classes)+'_fromcmdm.pkl')" + "# alt_mega_test = load_collector('fig/test'+str(M_classes)+'_fromcmdm.pkl')" ] }, { @@ -1452,17 +1452,17 @@ "source": [ "# connect lines along systematic, weighting, and affected class\n", "def wt_only_plot(dataset, metric_info, shapes, style='rel'):\n", - " baselines = dataset.keys()\n", + " baselines = list(dataset.keys())\n", " fig = pylab.figure(figsize=(10.2, 10.))\n", " bigAxes = pylab.axes(frameon=False) # hide frame\n", " bigAxes.set_xticks([]) # don't want to see any ticks on this axis\n", " bigAxes.set_yticks([])\n", " bigAxesP = bigAxes.twinx()\n", " bigAxesP.set_yticks([])\n", - " bigAxes.set_ylabel(metric_info.keys()[0], fontsize=20, labelpad=25, color=metric_info.values()[0])\n", + " bigAxes.set_ylabel(list(metric_info.keys())[0], fontsize=20, labelpad=25, color=list(metric_info.values())[0])\n", " bigAxes.set_xlabel(style+r'. 
weight on class', fontsize=20, labelpad=25)\n", - " bigAxesP.set_ylabel(metric_info.keys()[1], rotation=270, fontsize=20, \n", - " labelpad=50, color=metric_info.values()[1])\n", + " bigAxesP.set_ylabel(list(metric_info.keys())[1], rotation=270, fontsize=20, \n", + " labelpad=50, color=list(metric_info.values())[1])\n", " bigAxes.set_title('tunnel on baselines')\n", " for si in range(len(baselines)):\n", " s = baselines[si]\n", @@ -1493,12 +1493,12 @@ " ax.text(.5, .9, s, \n", " horizontalalignment='center',\n", " transform=ax.transAxes, fontsize=20)\n", - " ax.plot(wts[style].T[0], dataset[s][metric_info.keys()[0]],\n", + " ax.plot(wts[style].T[0], dataset[s][list(metric_info.keys())[0]],\n", " marker=shapes[s],\n", - " alpha=0.5, c=metric_info.values()[0])\n", - " axp.plot(wts[style].T[0], dataset[s][metric_info.keys()[1]],\n", + " alpha=0.5, c=list(metric_info.values())[0])\n", + " axp.plot(wts[style].T[0], dataset[s][list(metric_info.keys())[1]],\n", " marker=shapes[s],\n", - " alpha=0.5, c=metric_info.values()[1])\n", + " alpha=0.5, c=list(metric_info.values())[1])\n", " ax.set_ylim(-0.05, 2.75)\n", " axp.set_ylim(-0.001, 0.081)\n", "# ax.set_xlim(-0.25, 2.25)\n", @@ -2071,15 +2071,22 @@ "# jupyter nbconvert desc_note/main.ipynb --TagRemovePreprocessor.remove_cell_tags='[\"hideme\"]'\n", "# jupyter nbconvert desc_note/main.ipynb --TagRemovePreprocessor.remove_input_tags='[\"hidein\"]'\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "anaconda-cloud": {}, "celltoolbar": "Tags", "kernelspec": { - "display_name": "ProClaM (Python 3)", - "language": "python", - "name": "proclam_3" + "display_name": "Python 3", + "language": "python3", + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -2091,7 +2098,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.5" + "version": "3.6.8" } }, "nbformat": 4, diff --git a/paper/kaggle-run.ipynb b/paper/kaggle-run.ipynb deleted file mode 100644 index f8ce066..0000000 --- a/paper/kaggle-run.ipynb +++ /dev/null @@ -1,721 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# The Photometric LSST Astronomical Time-series Classification Challenge (PLAsTiCC): code that runs the performance metric\n", - "\n", - "*Alex Malz (NYU)*, *Renee Hlozek (U. Toronto)*, *Tarek Alam (UCL)*, *Anita Bahmanyar (U. Toronto)*, *Rahul Biswas (U. Stockholm)*, *Rafael Martinez-Galarza (Harvard)*, *Gautham Narayan (STScI)*" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "\n", - "# This is the code available on GitHub for calculating metrics,\n", - "# as well as performing other diagnostics on probability tables." 
- ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "def make_class_pairs(data_info_dict):\n", - " \"\"\"\n", - " Paris the paths to classifier output and truth tables for each classifier.\n", - "\n", - " Parameters\n", - " ----------\n", - " data_info_dict: dictionary\n", - " \n", - " Returns\n", - " -------\n", - " data_info_dict: dictionary\n", - " updated keywords: class_pairs, dict - classifier: [path_to_class_output, path_to_truth_tables] \n", - " \"\"\"\n", - " \n", - " for name in data_info_dict['names']:\n", - " data_info_dict['class_pairs'][name] = [data_info_dict['classifications'][name], data_info_dict['truth_tables'][name]]\n", - " \n", - " return(data_info_dict['class_pairs'])\n", - " \n", - "def make_file_locs(data_info_dict):\n", - " \"\"\"\n", - " Set paths to data directory, classifier output and truth tables.\n", - "\n", - " Parameters\n", - " ----------\n", - " data_info_dict: dictionary \n", - " \n", - " Returns\n", - " -------\n", - " data_info_dict: dictionary\n", - " updated keywords: dirname - data directory, str\n", - " classifications, dict - classifier: path to classifier output, str\n", - " truth_tables, dict - classifier: path to truth tables - str\n", - " \"\"\"\n", - " \n", - " # get the names of classifiers to be considered\n", - " names = data_info_dict['names']\n", - " \n", - " # set data directory\n", - " data_info_dict['dirname'] = topdir + data_info_dict['label'] + '/'\n", - "\n", - " for name in names:\n", - " # get the path to classifier output\n", - " data_info_dict['classifications'][name] = '%s/predicted_prob_%s.csv'%(name, name)\n", - " \n", - " # get the path to truth table\n", - " data_info_dict['truth_tables'][name] = '%s/truth_table_%s.csv'%(name, name)\n", - " \n", - " return data_info_dict\n", - "\n", - "def process_strings(dataset, cc):\n", - " \"\"\"\n", - " Get info on directory name and classifier.\n", - "\n", - " Parameters\n", - " ----------\n", - " dataset: dictionary \n", - " cc: classifier name, str\n", - " \n", - " Returns\n", - " -------\n", - " loc: data directory, str\n", - " text: version label, str\n", - " \"\"\"\n", - " \n", - " loc = dataset['dirname']\n", - " text = dataset['label'] + ' ' + dataset['names'][cc]\n", - " \n", - " return loc, text\n", - "\n", - "def just_read_class_pairs(pair, dataset, cc):\n", - " \"\"\"\n", - " Reads predicted probabilities and truth table.\n", - "\n", - " Parameters\n", - " ----------\n", - " pair: list of str - [path_to_classifier_output, path_to_truth_table]\n", - " dataset: dictionary\n", - " cc: classifier name, str\n", - " \n", - " Returns\n", - " -------\n", - " prob_mat: probability matrix (output from classifier)\n", - " tvec: truth vector\n", - " \"\"\"\n", - " \n", - " loc, text = process_strings(dataset, cc)\n", - " clfile = pair[0]\n", - " truthfile = pair[1]\n", - " \n", - " # read classifier output\n", - " prob_mat = pd.read_csv(loc + clfile, delim_whitespace=True).values\n", - " \n", - " # read truth table\n", - " truth_values = pd.read_csv(loc + truthfile, delim_whitespace=True).values\n", - " tvec = np.where(truth_values==1)[1]\n", - " \n", - " return prob_mat, tvec" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# Build dictionary to store classification results\n", - "mystery = {}\n", - "mystery['label'] = 'Unknown'\n", - "mystery['names'] = ['RandomForest', 'KNeighbors', 'MLPNeuralNet']\n", - "mystery['classifications'] = {}\n", - 
"mystery['truth_tables'] = {}\n", - "mystery['class_pairs'] = {}\n", - "mystery['probs'] = {}\n", - "mystery['truth'] = {}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# Read classifier output and truth tables\n", - "topdir = '../examples/'\n", - "mystery = make_file_locs(mystery)\n", - "mystery['class_pairs'] = make_class_pairs(mystery)\n", - "for nm, name in enumerate(mystery['names']):\n", - " probm, truthv = just_read_class_pairs(mystery['class_pairs'][name], mystery, nm)\n", - " mystery['probs'][name] = probm\n", - " mystery['truth'][name] = truthv\n", - "M_classes = np.shape(probm)[-1]\n", - "\n", - "# we need the class labels in the dataset in a consistently sorted order \n", - "# and will assume the weights of the weightvec correspond to this order\n", - "class_labels = sorted(np.unique(mystery['truth']['RandomForest']))" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'RandomForest': ['RandomForest/predicted_prob_RandomForest.csv',\n", - " 'RandomForest/truth_table_RandomForest.csv'],\n", - " 'KNeighbors': ['KNeighbors/predicted_prob_KNeighbors.csv',\n", - " 'KNeighbors/truth_table_KNeighbors.csv'],\n", - " 'MLPNeuralNet': ['MLPNeuralNet/predicted_prob_MLPNeuralNet.csv',\n", - " 'MLPNeuralNet/truth_table_MLPNeuralNet.csv']}" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "topdir = '../examples/'\n", - "mystery = make_file_locs(mystery)\n", - "mystery['class_pairs'] = make_class_pairs(mystery)\n", - "mystery['class_pairs']" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Method (Metric)\n", - "======\n", - "\n", - "The log-loss is defined as\n", - "\\begin{eqnarray*}\n", - "L &=& -\\sum_{m=1}^{M}\\frac{w_{m}}{N_{m}}\\sum_{n=1}^{N_{m}}\\ln[p_{n}(m | m)]\n", - "\\end{eqnarray*}\n", - "\n", - "We calculate the metric within each class $m$ by taking an average of its value $-\\ln[p_{n}(m | m)]$ for each true member $n$ of the class. Then we weight the metrics for each class by an arbitrary weight $w_{m}$ and take a weighted average of the per-class metrics to produce a global scalar metric $L$." 
- ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "\"\"\"\n", - "This is all from proclam but copied here so no one has to install it.\n", - "\"\"\"\n", - "\n", - "import numpy as np\n", - "import sys\n", - "import collections\n", - "\n", - "\"\"\"\n", - "Utility functions for PLAsTiCC metrics\n", - "\"\"\"\n", - "\n", - "# from __future__ import absolute_import, division\n", - "# __all__ = ['sanitize_predictions',\n", - "# 'weight_sum', 'averager', 'check_weights',\n", - "# 'det_to_prob',\n", - "# 'prob_to_det',\n", - "# 'det_to_cm', 'prob_to_cm',\n", - "# 'cm_to_rate', 'det_to_rate', 'prob_to_rate']\n", - "\n", - "RateMatrix = collections.namedtuple('rates', 'TPR FPR FNR TNR')\n", - "\n", - "def sanitize_predictions(predictions, epsilon=sys.float_info.epsilon):\n", - " \"\"\"\n", - " Replaces 0 and 1 with 0+epsilon, 1-epsilon\n", - "\n", - " Parameters\n", - " ----------\n", - " predictions: numpy.ndarray, float\n", - " N*M matrix of probabilities per object, may have 0 or 1 values\n", - " epsilon: float\n", - " small placeholder number, defaults to floating point precision\n", - "\n", - " Returns\n", - " -------\n", - " predictions: numpy.ndarray, float\n", - " N*M matrix of probabilities per object, no 0 or 1 values\n", - " \"\"\"\n", - " assert epsilon > 0. and epsilon < 0.0005\n", - " mask1 = (predictions < epsilon)\n", - " mask2 = (predictions > 1.0 - epsilon)\n", - "\n", - " predictions[mask1] = epsilon\n", - " predictions[mask2] = 1.0 - epsilon\n", - " predictions = predictions / np.sum(predictions, axis=1)[:, np.newaxis]\n", - " return predictions\n", - "\n", - "def det_to_prob(dets, prediction=None):\n", - " \"\"\"\n", - " Reformats vector of class assignments into matrix with 1 at true/assigned class and zero elsewhere\n", - "\n", - " Parameters\n", - " ----------\n", - " dets: numpy.ndarray, int\n", - " vector of classes\n", - " prediction: numpy.ndarray, float, optional\n", - " predicted class probabilities\n", - "\n", - " Returns\n", - " -------\n", - " probs: numpy.ndarray, float\n", - " matrix with 1 at input classes and 0 elsewhere\n", - "\n", - " Notes\n", - " -----\n", - " det_to_prob formerly truth_reformatter\n", - " Does not yet handle number of classes in truth not matching number of classes in prediction, i.e. for having \"other\" class or secret classes not in training set. 
The prediction keyword is a kludge to enable this but should be replaced.\n", - " \"\"\"\n", - " N = len(dets)\n", - " indices = range(N)\n", - "\n", - " if prediction is None:\n", - " prediction_shape = (N, int(np.max(dets) + 1))\n", - " else:\n", - " prediction, dets = np.asarray(prediction), np.asarray(dets)\n", - " prediction_shape = np.shape(prediction)\n", - "\n", - " probs = np.zeros(prediction_shape)\n", - " probs[indices, dets] = 1.\n", - "\n", - " return probs\n", - "\n", - "def prob_to_det(probs):\n", - " \"\"\"\n", - " Converts probabilistic classifications to deterministic classifications by assigning the class with highest probability\n", - "\n", - " Parameters\n", - " ----------\n", - " probs: numpy.ndarray, float\n", - " N * M matrix of class probabilities\n", - "\n", - " Returns\n", - " -------\n", - " dets: numpy.ndarray, int\n", - " maximum probability classes\n", - " \"\"\"\n", - " dets = np.argmax(probs, axis=1)\n", - "\n", - " return dets\n", - "\n", - "def det_to_cm(dets, truth, per_class_norm=True, vb=False):\n", - " \"\"\"\n", - " Converts deterministic classifications and truth into confusion matrix\n", - "\n", - " Parameters\n", - " ----------\n", - " dets: numpy.ndarray, int\n", - " assigned classes\n", - " truth: numpy.ndarray, int\n", - " true classes\n", - " per_class_norm: boolean, optional\n", - " equal weight per class if True, equal weight per object if False\n", - " vb: boolean, optional\n", - " if True, print cm\n", - "\n", - " Returns\n", - " -------\n", - " cm: numpy.ndarray, int\n", - " confusion matrix\n", - "\n", - " Notes\n", - " -----\n", - " I need to fix the norm keyword all around to enable more options, like normed output vs. not.\n", - " \"\"\"\n", - " pred_classes, pred_counts = np.unique(dets, return_counts=True)\n", - " true_classes, true_counts = np.unique(truth, return_counts=True)\n", - " if vb: print((pred_classes, pred_counts), (true_classes, true_counts))\n", - "\n", - " M = max(max(pred_classes), max(true_classes)) + 1\n", - "\n", - " cm = np.zeros((M, M), dtype=float)\n", - " # print((np.shape(dets), np.shape(truth)))\n", - " coords = np.array(list(zip(dets, truth)))\n", - " indices, index_counts = np.unique(coords, axis=0, return_counts=True)\n", - " # if vb: print(indices, index_counts)\n", - " indices = indices.T\n", - " # if vb: print(np.shape(indices))\n", - " cm[indices[0], indices[1]] = index_counts\n", - " if vb: print(cm)\n", - "\n", - " if per_class_norm:\n", - " # print(type(cm))\n", - " # print(type(true_counts))\n", - " # cm = cm / true_counts\n", - " # cm /= true_counts[:, np.newaxis] #\n", - " cm = cm / true_counts[np.newaxis, :]\n", - "\n", - " if vb: print(cm)\n", - "\n", - " return cm\n", - "\n", - "def prob_to_cm(probs, truth, per_class_norm=True, vb=False):\n", - " \"\"\"\n", - " Turns probabilistic classifications into confusion matrix by taking maximum probability as deterministic class\n", - "\n", - " Parameters\n", - " ----------\n", - " probs: numpy.ndarray, float\n", - " N * M matrix of class probabilities\n", - " truth: numpy.ndarray, int\n", - " N-dimensional vector of true classes\n", - " per_class_norm: boolean, optional\n", - " equal weight per class if True, equal weight per object if False\n", - " vb: boolean, optional\n", - " if True, print cm\n", - "\n", - " Returns\n", - " -------\n", - " cm: numpy.ndarray, int\n", - " confusion matrix\n", - " \"\"\"\n", - " dets = prob_to_det(probs)\n", - "\n", - " cm = det_to_cm(dets, truth, per_class_norm=per_class_norm, vb=vb)\n", - "\n", - " return cm\n", - 
"\n", - "def cm_to_rate(cm, vb=False):\n", - " \"\"\"\n", - " Turns a confusion matrix into true/false positive/negative rates\n", - "\n", - " Parameters\n", - " ----------\n", - " cm: numpy.ndarray, int or float\n", - " confusion matrix, first axis is predictions, second axis is truth\n", - " vb: boolean, optional\n", - " print progress to stdout?\n", - "\n", - " Returns\n", - " -------\n", - " rates: named tuple, float\n", - " RateMatrix named tuple\n", - "\n", - " Notes\n", - " -----\n", - " BROKEN!\n", - " This can be done with a mask to weight the classes differently here.\n", - " \"\"\"\n", - " if vb: print(cm)\n", - " diag = np.diag(cm)\n", - " if vb: print(diag)\n", - "\n", - " TP = np.sum(diag)\n", - " FN = np.sum(np.sum(cm, axis=0) - diag)\n", - " FP = np.sum(np.sum(cm, axis=1) - diag)\n", - " TN = np.sum(cm) - TP\n", - " if vb: print((TP, FN, FP, TN))\n", - "\n", - " T = TP + TN\n", - " F = FP + FN\n", - " P = TP + FP\n", - " N = TN + FN\n", - " if vb: print((T, F, P, N))\n", - "\n", - " TPR = TP / P\n", - " FPR = FP / N\n", - " FNR = FN / P\n", - " TNR = TN / N\n", - "\n", - " rates = RateMatrix(TPR=TPR, FPR=FPR, FNR=FNR, TNR=TNR)\n", - " if vb: print(rates)\n", - "\n", - " return rates\n", - "\n", - "def det_to_rate(dets, truth, per_class_norm=True, vb=False):\n", - " cm = det_to_cm(dets, truth, per_class_norm=per_class_norm, vb=vb)\n", - " rates = cm_to_rate(cm, vb=vb)\n", - " return rates\n", - "\n", - "def prob_to_rate(probs, truth, per_class_norm=True, vb=False):\n", - " cm = prob_to_cm(probs, truth, per_class_norm=per_class_norm, vb=vb)\n", - " rates = cm_to_rate(cm, vb=vb)\n", - " return rates\n", - "\n", - "def weight_sum(per_class_metrics, weight_vector, norm=True):\n", - " \"\"\"\n", - " Calculates the weighted metric\n", - "\n", - " Parameters\n", - " ----------\n", - " per_class_metrics: numpy.float\n", - " the scores separated by class (a list of arrays)\n", - " weight_vector: numpy.ndarray floar\n", - " The array of weights per class\n", - " norm: boolean, optional\n", - "\n", - " Returns\n", - " -------\n", - " weight_sum: np.float\n", - " The weighted metric\n", - " \"\"\"\n", - " weight_sum = np.dot(weight_vector, per_class_metrics)\n", - "\n", - " return weight_sum\n", - "\n", - "def check_weights(avg_info, M, truth=None):\n", - " \"\"\"\n", - " Converts standard weighting schemes to weight vectors for weight_sum\n", - "\n", - " Parameters\n", - " ----------\n", - " avg_info: str or numpy.ndarray, float\n", - " keyword about how to calculate weighted average metric\n", - " M: int\n", - " number of classes\n", - " truth: numpy.ndarray, int, optional\n", - " true class assignments\n", - "\n", - " Returns\n", - " -------\n", - " weights: numpy.ndarray, float\n", - " relative weights per class\n", - " \"\"\"\n", - " if type(avg_info) != str:\n", - " avg_info = np.asarray(avg_info)\n", - " weights = avg_info / np.sum(avg_info)\n", - " assert(np.isclose(sum(weights), 1.))\n", - " elif avg_info == 'per_class':\n", - " weights = np.ones(M) / float(M)\n", - " elif avg_info == 'per_item':\n", - " classes, counts = np.unique(truth, return_counts=True)\n", - " weights = np.zeros(M)\n", - " weights[classes] = counts / float(len(truth))\n", - " assert len(weights) == M\n", - " return weights\n", - "\n", - "def averager(per_object_metrics, truth, M):\n", - " \"\"\"\n", - " Creates a list with the metrics per object, separated by class\n", - " \"\"\"\n", - " group_metric = per_object_metrics\n", - " class_metric = np.empty(M)\n", - " for m in range(M):\n", - " true_indices = 
np.where(truth == m)[0]\n", - " how_many_in_class = len(true_indices)\n", - " try:\n", - " assert(how_many_in_class > 0)\n", - " per_class_metric = group_metric[true_indices]\n", - " # assert(~np.all(np.isnan(per_class_metric)))\n", - " class_metric[m] = np.average(per_class_metric)\n", - " except AssertionError:\n", - " class_metric[m] = 0.\n", - " # print((m, how_many_in_class, class_metric[m]))\n", - " return class_metric\n", - "\n", - "\"\"\"\n", - "A superclass for metrics\n", - "\"\"\"\n", - "class Metric(object):\n", - "\n", - " def __init__(self, scheme=None, **kwargs):\n", - " \"\"\"\n", - " An object that evaluates a function of the true classes and class probabilities\n", - "\n", - " Parameters\n", - " ----------\n", - " scheme: string\n", - " the name of the metric\n", - " \"\"\"\n", - " self.scheme = scheme\n", - "\n", - " def evaluate(self, prediction, truth, weights=None, **kwds):\n", - " \"\"\"\n", - " Evaluates a function of the truth and prediction\n", - "\n", - " Parameters\n", - " ----------\n", - " prediction: numpy.ndarray, float\n", - " predicted class probabilities\n", - " truth: numpy.ndarray, int\n", - " true classes\n", - " weights: numpy.ndarray, float\n", - " per-class weights\n", - "\n", - " Returns\n", - " -------\n", - " metric: float\n", - " value of the metric\n", - " \"\"\"\n", - " print('No metric specified: returning true positive rate based on maximum value')\n", - "\n", - " return # metric\n", - "\n", - "\"\"\"\n", - "A metric subclass for the log-loss\n", - "\"\"\"\n", - "class LogLoss(Metric):\n", - " def __init__(self, scheme=None):\n", - " \"\"\"\n", - " An object that evaluates the log-loss metric\n", - "\n", - " Parameters\n", - " ----------\n", - " scheme: string\n", - " the name of the metric\n", - " \"\"\"\n", - " super(LogLoss, self).__init__(scheme)\n", - " self.scheme = scheme\n", - "\n", - " def evaluate(self, prediction, truth, averaging='per_class'):\n", - " \"\"\"\n", - " Evaluates the log-loss\n", - "\n", - " Parameters\n", - " ----------\n", - " prediction: numpy.ndarray, float\n", - " predicted class probabilities\n", - " truth: numpy.ndarray, int\n", - " true classes\n", - " averaging: string or numpy.ndarray, float\n", - " 'per_class' weights classes equally, other keywords possible\n", - " vector assumed to be class weights\n", - "\n", - " Returns\n", - " -------\n", - " logloss: float\n", - " value of the metric\n", - "\n", - " Notes\n", - " -----\n", - " This uses the natural log.\n", - " \"\"\"\n", - " prediction, truth = np.asarray(prediction), np.asarray(truth)\n", - " prediction_shape = np.shape(prediction)\n", - " (N, M) = prediction_shape\n", - "\n", - " weights = check_weights(averaging, M, truth=truth)\n", - " truth_mask = det_to_prob(truth, prediction)\n", - "\n", - " prediction = sanitize_predictions(prediction)\n", - "\n", - " log_prob = np.log(prediction)\n", - " logloss_each = -1. 
* np.sum(truth_mask * log_prob, axis=1)[:, np.newaxis]\n", - "\n", - " # use a better structure for checking keyword support\n", - " class_logloss = averager(logloss_each, truth, M)\n", - "\n", - " logloss = weight_sum(class_logloss, weight_vector=weights)\n", - "\n", - " assert(~np.isnan(logloss))\n", - "\n", - " return logloss\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "KNeighbors with weights [0.05882353 0.11764706 0.05882353 0.05882353 0.05882353 0.05882353\n", - " 0.05882353 0.05882353 0.05882353 0.05882353 0.11764706 0.11764706\n", - " 0.11764706] has LogLoss = 20.749255306361132\n" - ] - } - ], - "source": [ - "# This is how you run the metric with a random weight vector.\n", - "\n", - "metric = 'LogLoss'\n", - "weightvec = np.ones(M_classes)\n", - "\n", - "# dummy example for SNPhotCC demo data\n", - "special_classes = (1, 10, 11, 12)\n", - "\n", - "# we should be using this for the PLAsTiCC data\n", - "# special_clases = (51, 99)\n", - "\n", - "mask = np.array([True if classname in special_classes else False for classname in class_labels])\n", - "weightvec[mask] = 2\n", - "weightvec = weightvec / sum(weightvec)\n", - "name = np.random.choice(mystery['names'])\n", - "probm = mystery['probs'][name]\n", - "truthv = mystery['truth'][name]\n", - "LL = LogLoss()\n", - "val = LL.evaluate(prediction=probm, truth=truthv, averaging=weightvec)\n", - "print(name+' with weights '+str(weightvec)+' has '+metric+' = '+str(val))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Acknowledgments\n", - "===============\n", - "\n", - "The DESC acknowledges ongoing support from the Institut National de Physique Nucleaire et de Physique des Particules in France; the Science & Technology Facilities Council in the United Kingdom; and the Department of Energy, the National Science Foundation, and the LSST Corporation in the United States.\n", - "\n", - "DESC uses resources of the IN2P3 Computing Center (CC-IN2P3--Lyon/Villeurbanne - France) funded by the Centre National de la Recherche Scientifique; the National Energy Research Scientific Computing Center, a DOE Office of Science User Facility supported by the Office of Science of the U.S. Department of Energy under Contract No. DE-AC02-05CH11231; STFC DiRAC HPC Facilities, funded by UK BIS National E-infrastructure capital grants; and the UK particle physics grid, supported by the GridPP Collaboration.\n", - "\n", - "This work was performed in part under DOE Contract DE-AC02-76SF00515." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Contributions\n", - "=======\n", - "\n", - "Alex Malz: conceptualization, data curation, formal analysis, investigation, methodology, project administration, software, supervision, validation, visualization, writing - original draft\n", - "\n", - "Renee Hlozek: data curation, formal analysis, funding acquisition, investigation, project administration, software, supervision, validation, visualization, writing - original draft\n", - "\n", - "Tarek Alam: investigation, software, validation\n", - "\n", - "Anita Bahmanyar: formal analysis, investigation, methodology, software, writing - original draft\n", - "\n", - "Rahul Biswas: conceptualization, methodology, software\n", - "\n", - "Rafael Martinez-Galarza: data curation, software, visualization\n", - "\n", - "Gautham Narayan: data curation, formal analysis" - ] - } - ], - "metadata": { - "anaconda-cloud": {}, - "celltoolbar": "Tags", - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.8" - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} diff --git a/paper/main.bib b/paper/main.bib index 862cd38..38b24f7 100644 --- a/paper/main.bib +++ b/paper/main.bib @@ -1,4 +1,64 @@ +@article{pedregosa_scikit-learn:_2011, + title = {Scikit-learn: {Machine} learning in {Python}}, + volume = {12}, + shorttitle = {Scikit-learn}, + number = {Oct}, + journal = {J Machine Learning Res}, + author = {Pedregosa, Fabian and Varoquaux, Ga{\"e}l and Gramfort, Alexandre and Michel, Vincent and Thirion, Bertrand and Grisel, Olivier and Blondel, Mathieu and Prettenhofer, Peter and Weiss, Ron and Dubourg, Vincent and {others}}, + year = {2011}, + pages = {2825--2830}, +} + +@book{oliphant_guide_2006, + address = {USA}, + title = {A guide to {NumPy}}, + publisher = {Trelgol Publishing}, + author = {Oliphant, Travis E.}, + year = {2006} +} + +@misc{jones_scipy:_2001, + title = {{SciPy}: {Open} {Source} {Scientific} {Tools} for {Python}}, + author = {Jones, Eric and Oliphant, Travis and Peterson, Pearu}, + year = {2001} +} + +@article{hunter_matplotlib:_2007, + title = {Matplotlib: {A} 2D {Graphics} {Environment}}, + volume = {9}, + issn = {1521-9615}, + shorttitle = {Matplotlib}, + doi = {10.1109/MCSE.2007.55}, + number = {3}, + journal = {Computing in Science Engineering}, + author = {Hunter, J. D.}, + month = may, + year = {2007}, + pages = {90--95}, +} + +@inproceedings{kluyver_jupyter_2016, + title = {Jupyter {Notebooks}-a publishing format for reproducible computational workflows.}, + booktitle = {{ELPUB}}, + author = {Kluyver, Thomas and Ragan-Kelley, Benjamin and P{\'e}rez, Fernando and Granger, Brian E. and Bussonnier, Matthias and Frederic, Jonathan and Kelley, Kyle and Hamrick, Jessica B. and Grout, Jason and Corlay, Sylvain}, + year = {2016}, + pages = {87--90}, +} + +@article{walt_numpy_2011, + title = {The {NumPy} {Array}: {A} {Structure} for {Efficient} {Numerical} {Computation}}, + volume = {13}, + issn = {1521-9615}, + shorttitle = {The {NumPy} {Array}}, + number = {2}, + journal = {Computing in Science \& Engineering}, + author = {Walt, S. v and Colbert, S. C. 
and Varoquaux, G.}, + year = {2011}, + doi = {10.1109/MCSE.2011.37}, + pages = {22--30}, +} + @article{kessler_supernova_2010, title = {Supernova {Photometric} {Classification} {Challenge}}, journal = {arXiv:1001.5210 [astro-ph]}, @@ -596,3 +656,87 @@ @article{richards_bayesian_2015 year = {2015}, pages = {39}, } + +@book{oliphant_python_2007, + title = {Python for {Scientific} {Computing}}, + volume = {9}, + author = {Oliphant, T.}, + month = may, + year = {2007}, + doi = {10.1109/MCSE.2007.58} +} + +@misc{malz_cosmological_2018, + title = {Cosmological {Hierarchical} {Inference} with {Probabilistic} {Photometric} {Redshifts}: aimalz/chippr}, + copyright = {MIT}, + shorttitle = {Cosmological {Hierarchical} {Inference} with {Probabilistic} {Photometric} {Redshifts}}, + author = {Malz, Alex I.}, + month = jul, + year = {2018}, +} + +@inproceedings{martin_det_1997, + title = {The {DET} curve in assessment of detection task performance}, + author = {Martin, Alvin F. and Doddington, George R. and Kamm, Terri and Ordowski, Mark and Przybocki, Mark A.}, + year = {1997}, +} + +@misc{malz_proclam_2018, + title = {{ProClaM}}, + author = {Malz, Alex I.}, + year = {2018}, + doi = {10.5281/zenodo.3352639} +} + +@article{buitinck_api_2013, + title = {{API} design for machine learning software: experiences from the scikit-learn project}, + shorttitle = {{API} design for machine learning software}, + journal = {arXiv:1309.0238 [cs]}, + author = {Buitinck, Lars and Louppe, Gilles and Blondel, Mathieu and Pedregosa, Fabian and Mueller, Andreas and Grisel, Olivier and Niculae, Vlad and Prettenhofer, Peter and Gramfort, Alexandre and Grobler, Jaques and Layton, Robert and Vanderplas, Jake and Joly, Arnaud and Holt, Brian and Varoquaux, Ga{\"e}l}, + month = sep, + year = {2013}, +} + +@phdthesis{bell_burnell_measurement_1969, + address = {Cambridge, UK}, + type = {Thesis}, + title = {The measurement of radio source diameters using a diffraction method}, + language = {en}, + school = {Department of Radio Astronomy, University of Cambridge}, + author = {Bell Burnell, Jocelyn}, + month = feb, + year = {1969}, + doi = {10.17863/CAM.4926}, +} + +@article{hewish_observation_1968, + title = {Observation of a {Rapidly} {Pulsating} {Radio} {Source}}, + volume = {217}, + copyright = {1968 Nature Publishing Group}, + issn = {1476-4687}, + doi = {10.1038/217709a0}, + language = {En}, + number = {5130}, + journal = {Nature}, + author = {Hewish, A. and Bell, S. J. and Pilkington, J. D. H. and Scott, P. F. and Collins, R. A.}, + month = feb, + year = {1968}, + pages = {709}, +} + +@article{the_plasticc_team_photometric_2018, + title = {The {Photometric} {LSST} {Astronomical} {Time}-series {Classification} {Challenge} ({PLAsTiCC}): {Data} set}, + shorttitle = {The {Photometric} {LSST} {Astronomical} {Time}-series {Classification} {Challenge} ({PLAsTiCC})}, + journal = {arXiv:1810.00001 [astro-ph]}, + author = {The PLAsTiCC team and Allam Jr., Tarek and Bahmanyar, Anita and Biswas, Rahul and Dai, Mi and Galbany, Llu{\'i}s and Hlo{\v z}ek, Ren{\'e}e and Ishida, Emille E. O. and Jha, Saurabh W. and Jones, David O. and Kessler, Richard and Lochner, Michelle and Mahabal, Ashish A. and Malz, Alex I. and Mandel, Kaisey S. and Mart{\'i}nez-Galarza, Juan Rafael and McEwen, Jason D. and Muthukrishna, Daniel and Narayan, Gautham and Peiris, Hiranya and Peters, Christina M. and Ponder, Kara and Setzer, Christian N. 
and Collaboration, The LSST Dark Energy Science and Transients, The LSST and Collaboration, Variable Stars Science}, + month = sep, + year = {2018}, +} + +@article{kessler_models_2019, + title = {Models and {Simulations} for the {Photometric} {LSST} {Astronomical} {Time} {Series} {Classification} {Challenge} ({PLAsTiCC})}, + journal = {arXiv:1903.11756 [astro-ph]}, + author = {Kessler, R. and Narayan, G. and Avelino, A. and Bachelet, E. and Biswas, R. and Brown, P. J. and Chernoff, D. F. and Connolly, A. J. and Dai, M. and Daniel, S. and Di Stefano, R. and Drout, M. R. and Galbany, L. and Gonz{\'a}lez-Gait{\'a}n, S. and Graham, M. L. and Hlo{\v z}ek, R. and Ishida, E. E. O. and Guillochon, J. and Jha, S. W. and Jones, D. O. and Mandel, K. S. and Muthukrishna, D. and O'Grady, A. and Peters, C. M. and Pierel, J. R. and Ponder, K. A. and Pr{\v s}a, A. and Rodney, S. and Villar, V. A.}, + month = mar, + year = {2019}, +} diff --git a/paper/main.ipynb b/paper/main.ipynb index 628f498..749dce6 100644 --- a/paper/main.ipynb +++ b/paper/main.ipynb @@ -28,12 +28,19 @@ "source": [ "import matplotlib as mpl\n", "# print(mpl.rcParams.items)\n", - "mpl.use('Agg')\n", + "mpl.use('PS')\n", "mpl.rcParams['text.usetex'] = False\n", "mpl.rcParams['mathtext.rm'] = 'serif'\n", "mpl.rcParams['font.family'] = 'serif'\n", - "mpl.rcParams['font.serif'] = ['Times New Roman']\n", - "# mpl.rcParams['font.family'] = ['Times New Roman']\n", + "mpl.rcParams[\"font.family\"] = \"serif\"\n", + "mpl.rcParams[\"mathtext.fontset\"] = \"dejavuserif\"\n", + "mpl.rcParams['font.serif'] = 'DejaVu Serif'\n", + "# mpl.rcParams['text.usetex'] = False\n", + "# mpl.rcParams['mathtext.rm'] = 'serif'\n", + "# mpl.rcParams['font.weight'] = 'light'\n", + "# mpl.rcParams['font.family'] = 'serif'\n", + "# mpl.rcParams['font.serif'] = ['Times New Roman']\n", + "# # mpl.rcParams['font.family'] = ['Times New Roman']\n", "mpl.rcParams['axes.titlesize'] = 25\n", "mpl.rcParams['axes.labelsize'] = 20\n", "mpl.rcParams['xtick.labelsize'] = 15\n", @@ -177,7 +184,7 @@ "# plt.hist(truth, log=True, alpha=0.5)\n", "ax.set_ylabel('counts', fontsize=20)\n", "ax.set_xlabel('class', fontsize=20)\n", - "plt.savefig('fig/mock_counts.png')\n", + "# plt.savefig('fig/mock_counts.png')\n", "plt.show()\n", "plt.close()" ] @@ -470,6 +477,22 @@ "plasticc = wrap_up_classifier(cm, 'Mutually Subsuming', plasticc, delta=0.1)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):\n", + " new_cmap = mpl.colors.LinearSegmentedColormap.from_list(\n", + " 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval),\n", + " cmap(np.linspace(minval, maxval, n)))\n", + " return new_cmap\n", + "\n", + "cmap = plt.get_cmap('hot_r')\n", + "fave_cmap = truncate_colormap(cmap, 0.3, 0.9)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -515,8 +538,8 @@ "# print(position)\n", " testname = info_dict['names'][i]\n", " \n", - " im = ax.imshow(info_dict['cm'][testname], vmin=0., vmax=1., cmap='winter_r')\n", - " ax.text(.5,.9,testname,horizontalalignment='center',transform=ax.transAxes, fontsize=16)\n", + " im = ax.imshow(info_dict['cm'][testname], vmin=0., vmax=1., cmap=fave_cmap)\n", + " ax.text(.5,.9,testname,horizontalalignment='center',transform=ax.transAxes, fontsize=14)\n", "# ax.tick_params(axis=u'both', which=u'both',length=10)\n", "# pylab.colorbar()\n", "# fig.subplots_adjust(right=0.5)\n", @@ -527,7 +550,7 @@ " bigAxes.set_ylabel(r'true 
class', fontsize=20, labelpad=-15)\n", " bigAxes.set_xlabel(r'predicted class', fontsize=20, labelpad=15)\n", " pylab.tight_layout()\n", - " pylab.savefig('fig/all_'+fn+'_cm.png', bbox_inches='tight', pad_inches=0)" + " pylab.savefig('fig/all_'+fn+'_cm.png', bbox_inches='tight', pad_inches=0, dpi=200)" ] }, { @@ -1919,9 +1942,9 @@ "anaconda-cloud": {}, "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3", - "language": "python3", - "name": "python3" + "display_name": "proclam (Python 3)", + "language": "python", + "name": "proclam_3" }, "language_info": { "codemirror_mode": { diff --git a/paper/main.tex b/paper/main.tex index 69d02b8..980f41d 100644 --- a/paper/main.tex +++ b/paper/main.tex @@ -94,10 +94,21 @@ \subsection*{Acknowledgments} % Standard papers only: A.B.C. acknowledges support from grant 1234 from ... This paper has undergone internal review in the LSST Dark Energy Science Collaboration. % REQUIRED if true -The authors would like to thank Melissa Graham, Weikang Lin, and Chad Schafer for serving as the LSST-DESC publication review committee. -The authors further wish to thank Tom Loredo for helpful feedback provided in the preparation of this paper. - -AIM is advised by David W. Hogg and was supported by National Science Foundation grant AST-1517237. +The authors would like to thank Melissa Graham, Weikang Lin, and Chad Schafer for serving as the LSST-DESC publication review committee, as well as Tom Loredo for other helpful feedback. +The authors also express gratitude to the anonymous referee for substantive suggestions that improved the paper. + +\changes{ +\software{ +jupyter \citep{kluyver_jupyter_2016}, +matplotlib \citep{hunter_matplotlib:_2007}, +numpy \citep{oliphant_guide_2006, oliphant_python_2007, walt_numpy_2011}, +proclam \citep{malz_proclam_2018}, +scikit-learn \citep{pedregosa_scikit-learn:_2011}, +scipy \citep{jones_scipy:_2001, buitinck_api_2013} +} +} + +AIM was advised by David W. Hogg and was supported by National Science Foundation grant AST-1517237. TA is supported in part by STFC. RB and CS are supported by the Swedish Research Council (VR) through the Oskar Klein Centre. Their work was further supported by the research environment grant ``Gravitational Radiation and Electromagnetic Astrophysical Transients (GREAT)'' funded by the Swedish Research council (VR) under Dnr 2016-06012. diff --git a/paper/tex/conclusions.tex b/paper/tex/conclusions.tex index a25ba76..97a6090 100644 --- a/paper/tex/conclusions.tex +++ b/paper/tex/conclusions.tex @@ -1,22 +1,33 @@ \section{Conclusion} \label{sec:conclusion} -As part of the preparation for \plasticc\, we investigate the properties of metrics suitable for probabilistic light curve classifications in the absence of a single scientific goal. -To that end, we sought a metric that avoids reducing classification probabilities to deterministic labels and one that rewards a classifier with strong performance across all classes over a classifier that performs well on a small subset of the classes and poorly on all others. +% intro and data +As part of the preparation for \plasticc\, we investigated the properties of metrics suitable for probabilistic light curve classifications in the absence of a single scientific goal. +Therefore, we sought a metric that avoids reducing classification probabilities to deterministic labels \changes{and is compatible with a multi-class, rather than binary (two-class), setting. 
+We did not consider some of the most popular metrics used in astronomy (such as accuracy, combinations of the true and false positive and negative rates, and AUC functions thereof) because they did not satisfy these criteria, even though it is in principle possible to extend such metrics for our situation.} +% In line with the goals of \plasticc, an important desideratum was to have a metric that tends to} reward a classifier's performance across all classes over a classifier that performs well on a small subset of the classes and poorly on others. +\changes{Our experimental design thus explores the response of potential metrics to simulated classification submissions from a set of mock classifier archetypes expected of generic transient and variable classifiers.} -We compared two metrics specific to probabilistic classifications: the Brier score and the log-loss. -Even though the Brier score and log-loss metrics take values consistent with one another, they are structurally and conceptually different, with wholly different interpretations. -The Brier score is a sum of square differences between probabilities; the explicit penalty term is an attractive feature, but it treats probabilities as generic scores and is not interpretable in terms of information. -The log-loss on the other hand is readily interpretable, meaning the metric itself could be propagated into forecasting the constraining power of \lsst, affecting the choice of observing strategy. +% the metrics +\changes{We identified two metrics of multi-class classification probabilities established in the literature: the Brier score and the log-loss.} +The Brier score and the log-loss metrics are structurally and conceptually different, with wholly different interpretations. +The Brier score is a sum of square differences between probabilities; +the explicit penalty term is an attractive feature, but it treats probabilities as generic scores. +The log-loss on the other hand is readily interpretable \changes{as a measure of information}, meaning the metric itself could be propagated into forecasting the cosmological constraining power of \lsst, affecting the choice of observing strategy. -We discovered that the log-loss is somewhat more sensitive to the systematic errors in classification that we find most concerning for generic scientific applications. -While both metrics could be appropriate for \plasticc, the log-loss is preferable due to its interpretability in terms of information. +% weights +\changes{When evaluated with equal weight on each classified object,} both the Brier score and the log-loss metrics are susceptible to rewarding a classifier that performs well on the most prevalent class and poorly on all others, which fails to meet the needs of \plasticc's diverse motivations \changes{under the unavoidable population imbalances of astronomical data}. +\changes{To discourage competitors from neglecting rare classes,} we explored a weighted average of the metric values on a per-class basis as a possible mitigation strategy to incentivize classifying uncommon classes, effectively ``leveling the playing field'' in the presence of highly nonuniform class membership. +%\changes{%Such weights were taken to be the same for all objects in the same class. -Both metrics are susceptible to rewarding a classifier that performs well on the most prevalent class and poorly on all others, which fails to meet the needs of \plasticc's diverse motivations. 
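For reference, one standard multi-class form of the Brier score, written here with the same per-class weighting convention as the weighted log-loss discussed in this work (the normalization adopted in the paper itself may differ), is
\begin{eqnarray*}
B &=& \sum_{m=1}^{M}\frac{w_{m}}{N_{m}}\sum_{n=1}^{N_{m}}\sum_{m'=1}^{M}\left(p_{n}(m' \mid m) - \delta_{m'm}\right)^{2},
\end{eqnarray*}
where $\delta_{m'm}$ is 1 when $m'$ is the true class $m$ and 0 otherwise; this sum of squared differences between the predicted probabilities and the one-hot truth vector is the quantity the text refers to.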
-We explored a weighted average of the metric values on a per-class basis as a possible mitigation strategy to incentivize classifying uncommon classes, effectively ``leveling the playing field'' in the presence of highly imbalanced class membership. -Although weights do impact the interpretability of the log-loss, we select a per-class weighted log-loss as the optimal choice for \plasticc. +% findings +On the basis of the mock classifier rankings, we found that both metrics reward the classifiers that are better and penalize those that are worse, where better and worse are defined by our common intuition, yielding the same rankings under either metric and demonstrating that both could be appropriate for \plasticc. +However, since only one could be selected, the log-loss was chosen due to its potential for interpretation after the conclusion of the challenge. +\changes{While modifying the log-loss metric to handle weights for different classes diminishes its interpretability, it can still be understood as information gain, subject to the value we as scientists place on knowledge of each class.} -% We note that in order to map on to the Kaggle evaluation platform, a metric weighted only by class was used for the general challenge, while a log-loss with more complicated weighting procedure will be used for the science competition (which will continue for an additional month after the main Kaggle release). +% justifying limited choice of metrics to consider +\changes{The space of possible metrics we could have considered is truly unbounded, from traditional metrics of deterministic labels to established extensions thereof for probabilistic classifications to novel quantities tuned to any given science case. +Though there was no need to do a more extensive survey of metrics nor to devise new metrics for \plasticc, since both log-loss and Brier score passed the basic sanity tests for this application, further work remains to be done in optimally selecting probabilistic classification metrics in other astronomical contexts.} We conclude by noting that care should be taken in planning future open challenges to ensure alignment between the challenge goals and the performance metric, so that efforts are best directed to achieve the challenge objectives. -We hope that this study of metric performance across a range of systematic effects and weights may serve as a guide to approaching the problem of identifying optimal probabilistic classifiers for general science applications. +It is our hope that this study of metric performance across a range of systematic effects and weights may serve as a guide to approaching the problem of identifying promising probabilistic classifiers for general science applications. diff --git a/paper/tex/data.tex b/paper/tex/data.tex index 0aaac29..f119014 100644 --- a/paper/tex/data.tex +++ b/paper/tex/data.tex @@ -16,8 +16,8 @@ \section{Data} \begin{figure} \begin{center} \includegraphics[width=0.49\textwidth]{./fig/complete_counts.png} - \caption{The number of objects in a given class as a function of class population size. - The true class populations are logarithmically distributed.} + \caption{\changes{The number of members of each of thirteen mock classes considered in this work.
+ Class populations were simulated by drawing the number of members of a given class from a logarithmic distribution to emulate the extreme class imbalances typical of astronomical samples.}} \label{fig:classdist} \end{center} \end{figure} @@ -56,7 +56,7 @@ \subsection{Mock classification schemes} % Without loss of generality, we can decompose $\mathbb{C}$ under some basis functions of parameters $\mathcal{C}$, the same as those introduced above in the definition of a classification posterior $p(m \mid d_{n}, D, \mathcal{C})$. % The CPM $\mathbb{C}$ thus defines the behavior of a classifier. -Assuming the light curves contain information about the true class (an assumption that underlies classification as a whole), we can use the appropriate row $\mathbb{C}_{m'_{n}} = p(\hat{m} \mid m', \mathcal{C})$ of the CPM $\mathbb{C}$ as a proxy for $p(m \mid d_{n}, D, \mathcal{C})$, without directly classifying light curves themselves.\footnote{This assumption is key to the generality of this work, which was condicted without any knowledge of the \plasticc\ dataset simulation procedure.} +Assuming the light curves contain information about the true class (an assumption that underlies classification as a whole), we can use the appropriate row $\mathbb{C}_{m'_{n}} = p(\hat{m} \mid m', \mathcal{C})$ of the CPM $\mathbb{C}$ as a proxy for $p(m \mid d_{n}, D, \mathcal{C})$, without directly classifying light curves themselves.\footnote{This assumption is key to the generality of this work, which was conducted without any knowledge of the \plasticc\ dataset simulation procedure.} To emulate the effect of natural variation of information content in different light curves (e.g. a noisy lightcurve has less information to recover than one with a higher signal-to-noise ratio) using the above, we generate a posterior probability vector $\vec{p}(m \mid m', \mathbb{C})$ by taking a Dirichlet-distributed draw \begin{eqnarray} \label{eq:cmtoprob} @@ -75,7 +75,7 @@ \subsection{Mock classification schemes} \begin{figure*} \begin{center} \includegraphics[width=0.8\textwidth]{./fig/all_sim_cm.png} - \caption{Conditional probability matrices for eight mock classifiers. + \caption{Conditional probability matrices \changes{(CPMs)} for eight mock classifiers. Top row: the uncertain classifier's uniform CPM; the perfect classifier's identity CPM; @@ -108,10 +108,10 @@ \subsection{Mock classification schemes} \begin{center} \includegraphics[width=0.49\textwidth]{./fig/combined.png}\\ \includegraphics[width=0.49\textwidth]{./fig/examples.png} - \caption{A realistically complex conditional probability matrix and classification posteriors drawn from it. + \caption{A realistically complex conditional probability matrix \changes{(CPM)} and classification posteriors drawn from it. Top: An example of a realistically complex conditional probability matrix, constructed by selecting a systematic for each individual class. This illustrates (for example), how a classifier may exhibit multiple systematics from Figure~\ref{fig:mock_cm} for each true class. - Bottom: Example classification probabilities, drawn from the above CPM, with their true class indicated by a red star and the systematic, characterized by its row in the CPM, affecting that true class described on the right. + Bottom: Example classification probabilities, drawn from the above CPM, with their true class indicated by a star and the systematic, characterized by its row in the CPM, affecting that true class described on the right. 
The Dirichlet process emulates the variation in classification posteriors due to differences between light curves within a given class, leading to different classification posteriors even among rows sharing a true class. } \label{fig:mock_probs} @@ -121,7 +121,7 @@ \subsection{Mock classification schemes} An actual classifier is expected to be more complex than the simplified cases of Figure~\ref{fig:mock_cm}, with different systematic behavior for each class. An example of a combined CPM across different classes and systematics is given in the top panel of Figure~\ref{fig:mock_probs}. The rows of this CPM correspond to rows of the archetypical classifiers of Figure~\ref{fig:mock_cm}. -To demonstrate the procedure by which mock classification posteriors are generated from rows of the CPM, we provide 22 examples of draws of the posterior CPM in the bottom panel of Figure~\ref{fig:mock_probs}. +To demonstrate the procedure by which mock classification posteriors are generated from rows of the CPM, we provide 26 examples of draws of the posterior CPM in the bottom panel of Figure~\ref{fig:mock_probs}. Given a set of true class identities, the mock classification posteriors of the bottom panel are Dirichlet draws from the corresponding row of the CPM of the top panel. \subsubsection{Uncertain classification} @@ -191,7 +191,7 @@ \subsection{Realistic classifications} \begin{figure*} \begin{center} \includegraphics[width=\textwidth]{./fig/all_snphotcc_cm.png} - \caption{Conditional probability matrices of the \citet{lochner_photometric_2016} methods applied to the second post-challenge release of the \snphotcc\ dataset. + \caption{Conditional probability matrices \changes{(CPMs)} of the \citet{lochner_photometric_2016} methods applied to the second post-challenge release of the \snphotcc\ dataset. Columns: the five machine learning methods of Boosted Decision Tree (BDT), K-Nearest Neighbors (KNN), Naive Bayes (NB), Neural Network (NN), and Support Vector Machine (SVM). Top row: five machine learning methods applied to template decompositions as features. Bottom row: the same five machine learning methods applied to wavelet decompositions as features. diff --git a/paper/tex/discussion.tex b/paper/tex/discussion.tex index 7c51513..a09e090 100644 --- a/paper/tex/discussion.tex +++ b/paper/tex/discussion.tex @@ -4,14 +4,16 @@ \section{Discussion} The goal of this work is to identify the metric most suited to \plasticc, which seeks classification posteriors of complete light curves similar to those anticipated from \lsst, with an emphasis on classification over all types, rewarding a ``best in show'' classifier rather than focusing on any one class or scientific application.\footnote{At the conclusion of \plasticc, other metrics specific to scientific uses of one or more particular classes will be used to identify ``best in class'' classification procedures that will be useful for more targeted science cases.} The weighted log-loss is thus the metric most suited to the current \plasticc\ release. -Future releases of \plasticc\ will focus on different challenges in transient and variable object classification, with metrics appropriate to identifying methodologies that best enable those goals. -We discuss approaches to identifying optimal metrics for these variations, which may be developed further in future work. 
+% \sout{Future releases of \plasticc\ will focus on different challenges in transient and variable object classification, with metrics appropriate to identifying methodologies that best enable those goals. +% We discuss approaches to identifying optimal metrics for these variations, which may be developed further in future work.} +\changes{Because transient and variable object classification is crucial for a variety of scientific objectives, the impact of a shared performance metric on this diversity of goals leads to complex and covariant trade-offs. +Though the selection criteria for metrics specific to each science goal are outside the scope of this work, which concerns only the first instantiation of \plasticc, we discuss below some issues concerning the identification of metrics for a few example science cases.} -\subsection{Early classification} +\subsection{\changes{Ongoing transient follow-up}} \label{sec:early} Spectroscopic follow-up is only expected of a small fraction of \lsst's detected transients and variable objects due to limited resources for such observations. -In addition to optical spectroscopic follow-up, photometric observations in other wavelength bands (near infrared and x-ray from space; microwave and radio from the ground) will be key to building a physical understanding of the object, particularly as we enter the era of multi-messenger astronomy with the added possibility of optical gravitational wave signatures. +In addition to optical spectroscopic follow-up, photometric observations in other wavelength bands (near infrared and x-ray from space; microwave and radio from the ground) \changes{or at different times} will be key to building a physical understanding of the object, particularly as we enter the era of multi-messenger astronomy with the added possibility of optical gravitational wave signatures. Prompt follow-up observations are highly informative for fitting models to the light curves of familiar source classes and to characterizing anomalous light curves that could indicate never-before-seen classes that have eluded identification due to rarity or faintness. As such, decisions about follow-up resource allocation must be made quickly and under the constraint that resources wasted on a misclassification consume the budget remaining for future follow-up attempts. A future version of \plasticc\ focused on early light curve classification should have a metric that accounts for these limitations and rewards classifiers that perform better even when fewer observations of the lightcurve are available. @@ -24,6 +26,9 @@ \subsection{Early classification} The critical question for choosing the most appropriate metric for any specific science goal motivating follow-up observations is to maximize information. We provide two examples of the kind of information one must maximize via early light curve classification and the qualities of a deterministic metric that might enable it. +\subsection{\changes{Spectroscopic supernova cosmology}} +\label{sec:spec_sncosmo} + Supernova cosmology with spectroscopically confirmed light curves benefits from true positives, which contribute to the constraining power of the analysis by including one more data point; when the class in which one is interested is as plentiful as SN Ia and our resources limited a priori, we may not be concerned by a high rate of false negatives. 
% requires making a decision balancing the improved constraining power of including another SN Ia in the analysis, thereby constraining the cosmological parameters, so only true positives contribute information, and if we had a perfect classifier and standard follow-up spectroscopy resources, there would be a maximum amount of information about the cosmological parameters that could be gained in this way. @@ -31,12 +36,20 @@ \subsection{Early classification} False positives, on the other hand, may not enter the cosmology analysis, but they consume follow-up resources, thereby depriving the endeavor of the constraining power due to a single SN Ia. A perfect classifier would lead to a maximum amount of information about the cosmological parameters conditioned on the follow-up resource budget. -For this scientific application, the metric must be chosen to not only maximize true positives but also to minimize false positives, and their relative impacts on the cosmological constraints can be quantified in terms of the information one would have about the cosmological parameters under different balances of true and false positives. -% balance the value of the information forgone by a false positive and the value of information forgone by a false negative, and the value placed on these is effectively weighted by the value we as researchers place on follow-up resources. -% \aim{Ciite some deterministic metrics relating to TP/FP?} +% \sout{For this scientific application, the metric must be chosen to balance the value of the information forgone by a false positive and the value of information forgone by a false negative, and the value placed on these is effectively weighted by the value we as researchers place on follow-up resources.} +\changes{Consider deterministic labels derived from cutoffs in probabilistic classifications for this scientific application; raising the probability cutoff reduces the number of false positives, boosting the cosmological constraining power, but at a cost of increasing the number of false negatives, which represent constraining power forgone. +As this tradeoff is asymmetric, it is insufficient to consider only the true and false positive and negative rates, as the \snphotcc\ FoM does, without propagating their impact on the information gained about the cosmological parameters.} +% \aim{Cite some deterministic metrics relating to TP/FP?} + +\subsection{\changes{Anomalous transient and variable detection}} +\label{sec:anom} -Anomaly detection also gains information only from true positives, but the cost function is different in that the potential gain of information from a true positive, since there is no information about undiscovered classes ahead of time. -An example would be the recent detection of a kilonova, flagged initially by the detection of gravitational waves from an object. +\changes{A particularly exciting science case is anomaly detection, the discovery of entirely unknown classes of transient or variable astrophysical sources, or distinguishing some of the rarest types of sources from more abundant types. +Like the case of spectroscopic supernova cosmology discussed above,} anomaly detection also gains information only from true positives, but the cost function is different in that the potential information gain is unbounded when there is no prior information about undiscovered classes. +% \aim{COMMENT RB: not to stay in doc, but I don't understand the prev sentence. 
I would also object to the recent detection of kilonova as a good example of anomaly detection, I can buy it if I squint very hard +% COMMENT AIM: Agreed, but I couldn't think of a better one at the time of writing.} +% \sout{An example would be the recent detection of a kilonova, flagged initially by the detection of gravitational waves from an object.} +\changes{The discovery of pulsars serves as an example of novelty detection enabled by a human classifier \citep{hewish_observation_1968, bell_burnell_measurement_1969}.} Resource availability for identifying new classes is more flexible, increasing when new predictions or promising preliminary observations attract attention, and decreasing when a discovery is confirmed and the new class is established. In this way, a false positive does not necessarily consume a resource that could otherwise be dedicated to a true positive, and the potential information gain is sufficiently great that additional resources would likely be allocated to observe the potential object. @@ -44,7 +57,7 @@ \subsection{Early classification} % For a rare event like a kilonova, a false negative represents an unbounfalse positive does not appreciably reduce the amount of remaining information available to collect, but a false negative represents a large quantity of information forgone. % Furthermore, r % In this case, the information forgone by a false negative is significant compared to the information forgone by a false positive. -Thus, a metric tuned to anomaly detection would aim to minimize the false negative rate and maximize the true positive rate. +Thus, a metric \changes{for evaluating} anomaly detection \changes{effectiveness} would aim to \changes{minimize the false negative rate and maximize the true positive rate.} % \aim{Cite some deterministic metrics relating to TP/FN?} % \subsection{Hierarchical classes} diff --git a/paper/tex/introduction.tex b/paper/tex/introduction.tex index 950ac13..46bee8a 100644 --- a/paper/tex/introduction.tex +++ b/paper/tex/introduction.tex @@ -13,9 +13,11 @@ \section{Introduction} % Thus several science cases (such as SN cosmology) will actively depend om classification of astrophysical sources based on the photometric , and possibly a much smaller training sample/model based on a spectroscopic sub-sample. As such, there is an acute need for classifiers of photometric light curves that can perform well on datasets that include a wide variety of sources including those that are at the limits of detection. -The Photometric \lsst\ Astronomical Time-series Classification Challenge (\plasticc\footnote{\url{http://plasticcblog.wordpress.com/}}) aims to identify and motivate the development of classification techniques that serve astronomical science goals by engaging the broader community outside astronomy. -\plasticc's dataset is comprehensive, including models for well-understood classes, newly observed classes, and classes that have only been proposed to exist, to simulate serendipitous discoveries anticipated of \lsst. 
-Additionally, \plasticc\ will join the ranks of a handful of past astronomy classification challenges including \citep[Mapping Dark Matter\footnote{\url{https://www.kaggle.com/c/mdm}}]{kitching_gravitational_2011}, \citep[Observing Dark Worlds\footnote{\url{https://www.kaggle.com/c/DarkWorlds}}]{harvey_observing_2013}, and \citep[the Galaxy Challenge\footnote{\url{https://www.kaggle.com/c/galaxy-zoo-the-galaxy-challenge}}]{dieleman_rotation-invariant_2015}, all hosted on Kaggle\footnote{\url{https://www.kaggle.com/}}, a platform that hosts data analytics competitions where seasoned professionals and amateurs alike can compete to classify, model, and predict large data sets uploaded by companies or scientific collaborations. +The Photometric \lsst\ Astronomical Time-series Classification Challenge (\plasticc\footnote{\url{http://plasticcblog.wordpress.com/}, \url{https://www.kaggle.com/c/PLAsTiCC-2018}}) aimed\footnote{\changes{\plasticc\ was run as a Kaggle challenge from 17 September 2018 to 17 December 2018. +Though \plasticc\ concluded prior to the final revision of this paper, the study herein was conducted entirely before the commencement of \plasticc, and the draft was submitted to the journal prior to \plasticc's conclusion, hence the use of the present and future tenses throughout this paper.}} +to identify and motivate the development of classification techniques that serve astronomical science goals by engaging the broader community outside astronomy. +\plasticc's dataset is comprehensive, including models for well-understood classes, newly observed classes, and classes that have only been proposed to exist, to simulate serendipitous discoveries anticipated of \lsst \citep{the_plasticc_team_photometric_2018, kessler_models_2019}. +Additionally, \plasticc\ joins the ranks of a handful of past astronomy classification challenges including \citep[Mapping Dark Matter\footnote{\url{https://www.kaggle.com/c/mdm}}]{kitching_gravitational_2011}, \citep[Observing Dark Worlds\footnote{\url{https://www.kaggle.com/c/DarkWorlds}}]{harvey_observing_2013}, and \citep[the Galaxy Challenge\footnote{\url{https://www.kaggle.com/c/galaxy-zoo-the-galaxy-challenge}}]{dieleman_rotation-invariant_2015}, all hosted on Kaggle\footnote{\url{https://www.kaggle.com/}}, a platform that hosts data analytics competitions where seasoned professionals and amateurs alike can compete to classify, model, and predict large data sets uploaded by companies or scientific collaborations. Kaggle attracts a broad userbase, and those without domain knowledge may provide novel approaches to the problem at hand. Classification in astronomy may proceed through images, as has been done in the contexts of galaxy classification \citep{hoyle_measuring_2016}, supernova classification \citep{cabrera-vives_deep-hits:_2017}, identification of bars in galaxies \citep{abraham_detection_2018}, weak lensing estimation\footnote{\url{http://great3challenge.info/}}\citep{mandelbaum_third_2014}, separation of Near Earth Asteroids from artifacts in images \citep{morii_machine-learning_2016}, as well as time-domain classification \citep{morii_machine-learning_2016, mahabal_deep-learnt_2017, zevin_gravity_2017}, and even noise classification \citep{zevin_gravity_2017, george_classification_2018}. @@ -75,7 +77,7 @@ \section{Introduction} %\item The metric must be reliable, giving consistent results for different instantiations of the same test case. 
%\end{itemize} -We perform a systematic exploration of the sensitivity of metrics of probabilistic classification to anticipated classifier failure modes using the PRObabilistic CLAssification Metric (\proclam) code, which is is publicly available on GitHub\footnote{\url{https://github.com/aimalz/proclam}}. +We perform a systematic exploration of the sensitivity of metrics of probabilistic classification to anticipated classifier failure modes using the PRObabilistic CLAssification Metric (\proclam) code \citep{malz_proclam_2018}, which is publicly available on GitHub\footnote{\url{https://github.com/aimalz/proclam}}. The mock classification submissions that we use for this study are described in Section~\ref{sec:data}. The metrics we consider are presented in Section~\ref{sec:methods}. The behavior of the metrics as a function of mock classification results is presented in Section~\ref{sec:results}. diff --git a/paper/tex/methods.tex b/paper/tex/methods.tex index b6a8911..ca66e48 100644 --- a/paper/tex/methods.tex +++ b/paper/tex/methods.tex @@ -128,12 +128,12 @@ \subsection{Weights} For example, requiring a minimum difference in probability density between the maximum probability class and the next highest probability class would help avert this degeneracy. % (e.g. a newly discovered supernova with a very small number of points may be indistinguishable from a Cataclysmic variable going through a brightening). -A simpler alternative that we investigate in this paper is to use a weighted average +\changes{A simpler alternative that we investigate in this paper is to use a weighted average \begin{eqnarray} \label{eq:weightavg} - Q_{m} &=& \frac{1}{\sum_{n} w_{n}} \sum_{n=1}^{N} w_{n} \sum_{m=1}^{M} Q_{n, m} + Q &=& \frac{1}{\sum_{m} w_{m}} \sum_{m} w_{m} Q_{m} \end{eqnarray} -of per-class metrics. +of per-class metrics $Q_{m}$.} (While weights could be assigned to each term $Q_{n, m}$, we do not consider this complexity at this time.) Weights that are not proportional to $N^{-1}$ nor $M^{-1}$ may be chosen to encourage challenge participants to direct more attention to classes with less active classification efforts or those that have been historically more difficult to classify due to observational limitations. @@ -141,4 +141,6 @@ \subsection{Weights} The weights for the \plasticc\ metric, however, must be determined before there is knowledge of which systematics affect which classes. Because of this caveat, the choice of weights is isolated to an inherently human problem dictated by the value placed on the scientific merits of knowledge of each class. This paper, on the other hand, can only quantify the impact of weights in relation to the systematics. -We thus agnostically test weighting schemes where classes affected by a particular systematic take a given weight $0 \leq w \leq 1$ and all other classes have a weight $(1 - w) / (M - 1)$. +We thus agnostically test weighting schemes\footnote{\changes{The weights considered in this study are more extreme than those ultimately used for \plasticc\ because the true weights were blinded from some authors prior to the end of the challenge. +However, we note that the weights could be (and in fact were) discovered by \plasticc\ competitors by systematically probing the output of the public leader board with entries from the cruise control classifier archetype targeting each class one at a time.}} +where classes affected by a particular systematic take a given weight $0 \leq w \leq 1$ and all other classes have a weight $(1 - w) / (M - 1)$. 
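[Editor's note: to make Equation~\ref{eq:weightavg} and the weighting scheme above concrete, here is a minimal Python sketch of a per-class weighted log-loss. The helper names (per_class_log_loss, weighted_log_loss, targeted_weights) are illustrative only and are not taken from the \proclam\ code; the epsilon clipping stands in for the finite-precision bound mentioned in the results.]

import numpy as np

def per_class_log_loss(probs, truth, m, eps=1e-8):
    # probs: (N, M) classification posteriors; truth: (N,) true class indices
    # mean negative log probability assigned to class m, over objects truly in class m
    in_class = (truth == m)
    p_true = np.clip(probs[in_class, m], eps, 1.0)  # clip to avoid log(0)
    return -np.mean(np.log(p_true))

def weighted_log_loss(probs, truth, weights):
    # Equation (eq:weightavg): weighted average of per-class metric values Q_m
    M = probs.shape[1]
    Q_m = np.array([per_class_log_loss(probs, truth, m) for m in range(M)])
    weights = np.asarray(weights, dtype=float)
    return np.sum(weights * Q_m) / np.sum(weights)

def targeted_weights(M, target, w):
    # agnostic scheme from the text: weight w on one targeted class,
    # (1 - w) / (M - 1) on every other class
    weights = np.full(M, (1.0 - w) / (M - 1))
    weights[target] = w
    return weights

[Setting w = 1 in targeted_weights places all weight on the targeted class, corresponding to the w = 1 cases discussed in the results, while w = 1/M recovers a uniform per-class average.]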
diff --git a/paper/tex/results.tex b/paper/tex/results.tex index f7e6921..eab81d2 100644 --- a/paper/tex/results.tex +++ b/paper/tex/results.tex @@ -65,10 +65,8 @@ \subsection{Mock classifier systematics} Subsumed from Almost & 0.641 & 1.629\\ Subsumed from Perfect & 1.0 & 18.421\footnote{The entry for the log-loss of a classifier that subsumes a class into one that is otherwise perfectly classified should be infinite but is bounded by the numerical precision of our calculations.} \end{tabular} -\caption{ -The value of each metric when the weight is entirely on the class with the indicated characteristic. -Weighting changes the metric performance: the value of each metric when the weight is entirely on the class with the indicated characteristic (correponsding to a $w=1$ case in Figure~\ref{fig:all_combined}). -The log-loss is more sensitive than te Brier score, with larger values of the score (indicating poor classification performance), particularly for the subsuming systematic. +\caption{\changes{Metric values computed using Equation~\ref{eq:weightavg} with all weight on the mock class affected by the indicated systematic, described in Sec.~\ref{sec:mockdata}, corresponding to the $w=1$ cases in Figure~\ref{fig:all_combined}. +While the log-loss metric has a larger dynamic range than the Brier score for poor classification, the archetypical classifiers would be ranked (lower values are better) the same way by either metric.} } \label{tab:extents} \end{table} @@ -89,8 +87,8 @@ \subsection{Mock classifier systematics} \end{tabular} \caption{ The slopes for each baseline-plus-systematic pair in the space of log-loss versus Brier score. -A higher slope corresponds to increased sensitivity of the log-loss over the Brier score. -The contrast between log-loss and Brier score is highest on a baseline of the perfect classifier, meaning the log-loss may be more appropriate for discriminating between classifiers that are already extremely good. +A higher slope corresponds to increased sensitivity of the log-loss over the Brier score \changes{to the systematic-baseline pair in question}. +The contrast between log-loss and Brier score is highest on a baseline of the perfect classifier, meaning the log-loss may \changes{more strongly discriminate} between classifiers that are already extremely good. } \label{tab:slopes} \end{table} @@ -144,50 +142,62 @@ \subsection{Mock classifier systematics} \subsection{Representative classifications} \label{sec:realresults} -We apply the log-loss and Brier metrics to the classification output from \snmachine. While the classification methods described in \citet{lochner_photometric_2016} refer to the idealized subset of the \snphotcc\ data, these approaches are the state-of-the-art in classification of extragalactic transients. We present in Table~\ref{tab:snmachineresults} the log-loss and Brier scores assuming an equal weight per object. -%, for classification probabilities derived from running the algorithms of \citet{lochner_photometric_2016} on the \snphotcc\ data of Section~\ref{sec:realdata}. -Table~\ref{tab:snmachineresults} also contains the ranking of classifier performance under each metric. +We apply the log-loss and Brier metrics to the classification output from \snmachine. +While the classification methods described in \citet{lochner_photometric_2016} refer to the idealized subset of the \snphotcc\ data, these approaches are the state-of-the-art in classification of extragalactic transients. 
+We present in \changes{Figure~\ref{fig:snmachineresults} the rankings under the} log-loss and Brier score metrics assuming an equal weight per object. +%, for classification probabilities derived from running the algorithms of \citet{lochner_photometric_2016} on the \snphotcc\ data of Section~\ref{sec:realdata}. -\begin{table*}[] - \begin{centering} -\begin{tabular}{lllllll}%ll} -Rank $R$ & $R_\mathrm{FoM}$ & FoM & %$R_\mathrm{AUC}$ & AUC & -$R_\mathrm{LogLoss}$ & Log-loss & $R_\mathrm{Brier}$ & Brier \\ -\hline -1 & TBDT & 0.635 %& TBDT & 0.982 -& TBDT & 0.0907 & TBDT & 0.0486 \\ -2 & WBDT & 0.591 %& WBDT & 0.978 -& TSVM & 0.113 & TSVM & 0.0583 \\ -3 & TSVM & 0.514 %& TSVM & 0.969 -& TNN & 0.125 & TNN & 0.0650 \\ -4 & WSVM & 0.499 %& WSVM & 0.968 -& WSVM & 0.1316 & WBDT & 0.0689 \\ -5 & TNN & 0.496 %& TNN & 0.954 -& WBDT & 0.1321 & WSVM & 0.0730 \\ -6 & WNN & 0.480 %& WNN & 0.946 -& TKNN & 0.146 & WNN & 0.0750 \\ -7 & TKNN & 0.384 %& TKNN & 0.942 -& WNN & 0.152 & TKNN & 0.0787 \\ -8 & TNB & 0.340 %& WKNN & 0.894 -& WKNN & 0.228 & TNB & 0.105 \\ -9 & WKNN & 0.114 %& TNB & 0.879 -& TNB & 0.251 & WKNN & 0.132 \\ -10 & WNB & 0.0365 %& WNB & 0.850 -& WNB & 0.443 & WNB & 0.178 \\ -\end{tabular} - \caption{ - The values of three metrics for each of ten \snmachine\ classifiers with equal weight per object. - The metrics broadly agree on the ranking of the classifiers, confirming consistency between a conventional metric of classification performance and the metrics of probabilistic classifications presented here. - However, there are some differences with pairwise swapping between the log-loss and Brier rankings and some significant reordering of ranks 2 through 5 with the FoM metric relative to the probabilistic metrics. - } - \label{tab:snmachineresults} - \end{centering} -\end{table*} +% \sout{\begin{table*}[] +% \begin{centering} +% \begin{tabular}{lllllll}%ll} +% Rank $R$ & $R_\mathrm{FoM}$ & FoM & %$R_\mathrm{AUC}$ & AUC & +% $R_\mathrm{LogLoss}$ & Log-loss & $R_\mathrm{Brier}$ & Brier \\ +% \hline +% 1 & TBDT & 0.635 %& TBDT & 0.982 +% & TBDT & 0.0907 & TBDT & 0.0486 \\ +% 2 & WBDT & 0.591 %& WBDT & 0.978 +% & TSVM & 0.113 & TSVM & 0.0583 \\ +% 3 & TSVM & 0.514 %& TSVM & 0.969 +% & TNN & 0.125 & TNN & 0.0650 \\ +% 4 & WSVM & 0.499 %& WSVM & 0.968 +% & WSVM & 0.1316 & WBDT & 0.0689 \\ +% 5 & TNN & 0.496 %& TNN & 0.954 +% & WBDT & 0.1321 & WSVM & 0.0730 \\ +% 6 & WNN & 0.480 %& WNN & 0.946 +% & TKNN & 0.146 & WNN & 0.0750 \\ +% 7 & TKNN & 0.384 %& TKNN & 0.942 +% & WNN & 0.152 & TKNN & 0.0787 \\ +% 8 & TNB & 0.340 %& WKNN & 0.894 +% & WKNN & 0.228 & TNB & 0.105 \\ +% 9 & WKNN & 0.114 %& TNB & 0.879 +% & TNB & 0.251 & WKNN & 0.132 \\ +% 10 & WNB & 0.0365 %& WNB & 0.850 +% & WNB & 0.443 & WNB & 0.178 \\ +% \end{tabular} +% \caption{ +% The values of three metrics for each of ten \snmachine\ classifiers with equal weight per object. +% The metrics broadly agree on the ranking of the classifiers, confirming consistency between a conventional metric of classification performance and the metrics of probabilistic classifications presented here. +% However, there are some differences with pairwise swapping between the log-loss and Brier rankings and some significant reordering of ranks 2 through 5 with the FoM metric relative to the probabilistic metrics. 
+% } +% \label{tab:snmachineresults} +% \end{centering} +% \end{table*}} + +\begin{figure} + \begin{center} + \includegraphics[width=0.49\textwidth]{./fig/Tables3_option4.png} + \caption{ + \changes{The rankings of each of the five \snmachine\ classification algorithms (Boosted Decision Tree (BDT), K-Nearest Neighbors (KNN), Naive Bayes (NB), Neural Network (NN), and Support Vector Machine (SVM)) on template (T*) and wavelet (W*) features with equal weight per object under the three metrics. + The metrics broadly agree on the ranking of the classifiers, confirming consistency between a conventional metric of classification performance and the metrics of probabilistic classifications presented here. + However, there are some differences with pairwise swapping between the log-loss and Brier rankings and some significant reordering of ranks 2 through 5 with the FoM metric relative to the probabilistic metrics.} + } + \end{center} + \label{fig:snmachineresults} +\end{figure} We apply our metrics to the classification output from \snmachine\ applied to the \snphotcc\ dataset as an example of representative light curves and representative classifiers used in extragalactic astronomy. -We present in Table~\ref{tab:snmachineresults} the log-loss and Brier scores assuming an equal weight per object, as well as the original \snphotcc\ metric described in Section~\ref{sec:deterministic}. -Table~\ref{tab:snmachineresults} also contains the ranking of classifier performance under each metric. +We present in \changes{Figure~\ref{fig:snmachineresults} the rankings of each classifier under the log-loss and Brier scores assuming an equal weight per object, as well as the original \snphotcc\ metric described in Section~\ref{sec:deterministic}.} The Brier score, log-loss, and \snphotcc\ FoM are in agreement as to the first- and last-ranked classifiers. This consensus indicates that both of the potential \plasticc\ metrics are roughly consistent with our intuition about what makes a good classifier, providing an anchor between accepted notions of an appropriate metric and the metrics of probabilistic classifications under consideration here. 
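[Editor's note: for reference, a minimal Python sketch of how the two probabilistic metrics ranked in Figure~\ref{fig:snmachineresults} could be computed with equal weight per object. Normalization conventions for the Brier score vary, and the names here (brier_score, log_loss, rank_classifiers) are illustrative rather than the \snmachine\ or \proclam\ implementations.]

import numpy as np

def brier_score(probs, truth):
    # sum of squared differences between each posterior vector and the
    # one-hot encoding of the true class, averaged over objects
    N, M = probs.shape
    onehot = np.zeros((N, M))
    onehot[np.arange(N), truth] = 1.0
    return np.mean(np.sum((probs - onehot) ** 2, axis=1))

def log_loss(probs, truth, eps=1e-8):
    # mean negative log posterior probability assigned to the true class
    N = probs.shape[0]
    p_true = np.clip(probs[np.arange(N), truth], eps, 1.0)
    return -np.mean(np.log(p_true))

def rank_classifiers(submissions, truth, metric):
    # submissions: dict mapping classifier name to its (N, M) posterior array;
    # returns names ordered from best (lowest score) to worst
    scores = {name: metric(probs, truth) for name, probs in submissions.items()}
    return sorted(scores, key=scores.get)

[Since lower values are better for both metrics, comparing rank_classifiers(submissions, truth, brier_score) against rank_classifiers(submissions, truth, log_loss) reproduces the kind of ranking agreement described above.]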
diff --git a/pipeline_sandbox.ipynb b/pipeline_sandbox.ipynb index 6f2dfa8..f85e160 100644 --- a/pipeline_sandbox.ipynb +++ b/pipeline_sandbox.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -20,8 +20,10 @@ "import random\n", "import numpy as np\n", "import scipy.stats as sct\n", + "import scipy.integrate as spi\n", "import sklearn as skl\n", "from sklearn import metrics\n", + "from pycm import ConfusionMatrix\n", "import pandas as pd\n", "import os\n", "\n", @@ -31,12 +33,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import proclam\n", - "from proclam import *" + "# from proclam import *" ] }, { @@ -55,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -72,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -88,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -96,7 +98,7 @@ "M_classes = 5\n", "N_objects = 1000\n", "names = [''.join(random.sample(string.ascii_lowercase, 2)) for i in range(M_classes)]\n", - "truth = A.simulate(M_classes, N_objects)" + "truth = A.simulate(M_classes, N_objects, base=2)" ] }, { @@ -108,9 +110,32 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 0, 'class')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZkAAAEKCAYAAADAVygjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAElpJREFUeJzt3XuwXWV9xvHvQwARgaiAbQXGoDBYpEoEqdWhTvEWpFzGAnJRqjDQzAjqTO2YVrx10FZr6whSaChIVe4qMZRYYJRyK8VcRCRE2og4hEG5FQQ0XH/9Y++M2+M5YR/Oec8+Z5/vZ2ZP9nrXe9b6rdlJnvOud+21UlVIktTCJoMuQJI0vAwZSVIzhowkqRlDRpLUjCEjSWrGkJEkNWPISJKaMWQkSc0YMpKkZjYddAGDtt1229W8efMGXYYkzSgrV668v6q2f7Z+sz5k5s2bx4oVKwZdhiTNKEl+2k8/T5dJkpoxZCRJzRgykqRmDBlJUjOGjCSpGUNGktSMISNJasaQkSQ1M+u/jClJAPMWXT5q+51bHNV2v+vPb7r9sdz59wdMyX4cyUiSmjFkJEnNGDKSpGack5H0W8aanxivYZ3PUP8cyUiSmjFkJEnNGDKSpGYMGUlSM4aMJKkZQ0aS1IwhI0lqxpCRJDUzlF/GTHIIcACwDXB2VV054JIkaVZqNpJJskWS7yX5QZLVST41gW2dk+TeJLeOsm5BktuTrE2yCKCqllTV8cBC4F3P/SgkSRPR8nTZ48B+VfUaYE9gQZLX93ZI8pIkW49o22WUbZ0LLBjZmGQOcDqwP7A7cGSS3Xu6nNxdL0kagGYhUx2Pdhc3675qRLc3AUuSPA8gyfHAaaNs61rgwVF2sw+wtqruqKongAuBg9PxWeDbVbVqco5IkjReTSf+k8xJcjNwL3BVVd3Uu76qLgGuAC5KcjRwLHDYOHaxA3BXz/K6bttJwFuAQ5MsHKO2A5Msfvjhh8exO0nSeDQNmap6uqr2BHYE9kmyxyh9PgesB84ADuoZ/Uxkv6dW1V5VtbCqzhyjz2VVdcLcuXMnujtJ0him5BLmqnoIuJrR51X2BfYALgU+Mc5N3w3s1LO8Y7dNkjQNtLy6bPskL+y+fz7wVuBHI/rMBxYDBwPvA7ZNcso4drMc2DXJzkk2B44Alk5G/ZKkiWs5kvk94Ookt9AJg6uq6t9H9NkSOLyqflxVzwDHAD8duaEkFwA3ArslWZfkOICqego4kc68zhrg4qpa3eyIJEnj0uzLmFV1CzD/WfrcMGL5SeCsUfoduZFtLAOWPccyJUkNeVsZSVIzhowkqRlDRpLUjCEjSWrGkJEkNWPISJKaMWQkSc0YMpKkZgwZSVIzhowkqRlDRpLUjCEjSWrGkJEkNWPISJKaMWQkSc0YMpKkZgwZSVIzhowkqRlDRpLUjCEjSWrGkJEkNWPISJKaMWQkSc0YMpKkZgwZSVIzhowkqRlDRpLUjCEjSWrGkJEkNWPISJKaMWQkSc0YMpKkZjYddAEtJDkEOADYBji7qq4ccEmSNCs1G8kk2SnJ1UluS7I6yQcnsK1zktyb5NZR1i1IcnuStUkWAVTVkqo6HlgIvOu5H4UkaSJajmSeAv6yqlYl2RpYmeSqqrptQ4ckLwF+VVWP9LTtUlVrR2zrXOBLwFd6G5PMAU4H3gqsA5YnWdqzj5O766XnbN6iywey3zu3OKr5PuatP7/5PjS7NRvJVNU9VbWq+/4RYA2ww4hubwKWJHkeQJLjgdNG2da1wIOj7GYfYG1V3VFVTwAXAgen47PAtzfUIEmaelMyJ5NkHjAfuKm3vaouSbIzcFGSS4Bj6YxK+rUDcFfP8jrgD4GTgLcAc7sjozNHqelA4MBddtllHLuTJI1H86vLkmwFfAP4UFX9YuT6qvocsB44Azioqh6d6D6r6tSq2quqFo4WMN0+l1XVCXPnzp3o7iRJY2gaMkk2oxMw51XVN8fosy+wB3Ap8Ilx7uJuYKee5R27bZKkaaDl1WUBzgbWVNU/jdFnPrAYOBh4H7BtklPGsZvlwK5Jdk6yOXAEsHRilUuSJkvLkcwbgfcA+yW5uft6x4g+WwKHV9WPq+oZ4BjgpyM3lOQC4EZgtyTrkhwHUFVPAScCV9C5sODiqlrd7pAkSePRbOK/qq4H8ix9bhix/CRw1ij9jtzINpYBy55jmZKkhrytjCSpGUNGktSMISNJasaQkSQ1Y8hIkpoxZCRJzRgykqRmDBlJUjOGjCSpGUNGktSMISNJasaQkSQ1Y8hIkpoxZCRJzRgykqRmDBlJUjOGjCSpGUNGktSMISNJasaQkSQ101fIJPlgkm3ScXaSVUne1ro4SdLM1u9I5tiq+gXwNuBFwHuAv29WlSRpKPQbMun++Q7gq1W1uqdNkqRR9RsyK5NcSSdkrkiyNfBMu7IkScNg0z77HQfsCdxRVb9Msi3wvnZlSZKGQb8jmauqalVVPQRQVQ8AX2hXliRpGGx0JJNkC2BLYLskL+LX8zDbADs0rk2SNMM92+myvwA+BLwUWMmvQ+YXwJca1iVJGgIbDZmq+iLwxSQnVdVpU1STJGlI9DXxX1WnJXkDMK/3Z6rqK43qkiQNgb5CJslXgVcANwNPd5sLMGQkSWPq9xLmvYHdq6paFiNJGi79XsJ8K/C7LQuRJA2ffkcy2wG3Jfke8PiGxqo6qElVkqSh0G/IfLJlEZKk4dTv1WXXtC5EkjR8+r267BE6V5MBbA5sBjxWVdu0KkySNPP1O5LZesP7JAEOBl7fqihJ0nAY9+OXq2MJ8PYG9UiShki/p8ve2bO4CZ3vzaxvUpEkaWj0e3XZgT3vnwLupHPKTJKkMfU7J+MDyiRJ49bXnEySHZNcmuTe7usbSXZsXZwkaWbrd+L/y8BSOs+VeSlwWbdNkqQx9Rsy21fVl6vqqe7rXGD7hnVJkoZAvyHzQJJ3J5nTfb0beKBlYZKkma/fkDkWOBz4GXAPcCjw3kY1SZKGRL+XMP8t8OdV9X8ASV4MfJ5O+EiSNKp+RzKv3hAwAFX1IDC/TUmSpGHRb8hskuRFGxa6I5l+R0GSpFmq36D4R+DGJJd0lw8DPt2mJEnSsOj3G/9fSbIC2K/b9M6quq1dWZKkYdD3Ka9uqBgskqS+jftW/5Ik9cuQkSQ1Y8hIkpoxZCRJzRgykqRmDBlJUjOGjCSpGUNGktSMISNJasaQkSQ1M1R3Uk5yCHAAsA1wdlVdOeCSJGlWm/YjmSTnJLk3ya0j2hckuT3J2iSLAKpqSVUdDywE3jWIeiVJvzbtQwY4F1jQ25BkDnA6sD+wO3Bkkt17upzcXS9JGqBpHzJVdS3w4IjmfYC1VXVHVT0BXAgcnI7PAt+uqlVTXask6TdN+5AZww7AXT3L67ptJwFvAQ5NsnCsH05yQpIVSVbcd999bSuVpFlsqCb+q+pU4NQ++i0GFgPsvffe1bouSZqtZupI5m5gp57lHbttkqRpZKaGzHJg1yQ7J9kcOAJYOuCaJEkjTPuQSXIBcCOwW5
J1SY6rqqeAE4ErgDXAxVW1epB1SpJ+27Sfk6mqI8doXwYsm+JyJEnjMO1HMpKkmcuQkSQ1Y8hIkpoxZCRJzRgykqRmZm3IJDkwyeKHH3540KVI0tCatSFTVZdV1Qlz584ddCmSNLRmbchIktozZCRJzRgykqRmDBlJUjOGjCSpGUNGktSMISNJasaQkSQ1Y8hIkpoxZCRJzczakPHeZZLU3qwNGe9dJkntzdqQkSS1Z8hIkpoxZCRJzRgykqRmDBlJUjOGjCSpGUNGktSMISNJasaQkSQ1Y8hIkpoxZCRJzRgykqRmZm3IeBdmSWpv1oaMd2GWpPZmbchIktozZCRJzRgykqRmDBlJUjOGjCSpGUNGktSMISNJasaQkSQ1Y8hIkpoxZCRJzQxlyCR5eZKzk3x90LVI0mzWNGSSvDDJ15P8KMmaJH/0HLdzTpJ7k9w6yroFSW5PsjbJIoCquqOqjpto/ZKkiWk9kvki8B9V9UrgNcCa3pVJXpJk6xFtu4yynXOBBSMbk8wBTgf2B3YHjkyy++SULkmaqGYhk2Qu8MfA2QBV9URVPTSi25uAJUme1/2Z44HTRm6rqq4FHhxlN/sAa7sjlyeAC4GDJ+8oJEkT0XIkszNwH/DlJN9P8q9JXtDboaouAa4ALkpyNHAscNg49rEDcFfP8jpghyTbJjkTmJ/kr0f7QZ8nI0nttQyZTYHXAmdU1XzgMWDRyE5V9TlgPXAGcFBVPTrRHVfVA1W1sKpeUVV/N0YfnycjSY21DJl1wLqquqm7/HU6ofMbkuwL7AFcCnxinPu4G9ipZ3nHbpskaRpoFjJV9TPgriS7dZveDNzW2yfJfGAxnXmU9wHbJjllHLtZDuyaZOckmwNHAEsnXLwkaVK0vrrsJOC8JLcAewKfGbF+S+DwqvpxVT0DHAP8dORGklwA3AjslmRdkuMAquop4EQ68zprgIuranWzo5EkjcumLTdeVTcDe29k/Q0jlp8Ezhql35Eb2cYyYNkEypQkNTKU3/iXJE0PhowkqRlDRpLUTNM5GQ2feYsun5Tt3LnFUZOynY2Zt/785vuQtHGOZCRJzRgykqRmDBlJUjPOyUzAZM1PPBet5zScz5A0GRzJSJKaMWQkSc0YMpKkZgwZSVIzhowkqZlZGzI+flmS2pu1IePjlyWpvVkbMpKk9gwZSVIzqapB1zBQSe5jlEc+zwDbAfcPuogpNNuOFzzm2WKmHvPLqmr7Z+s060NmpkqyoqrGfLT1sJltxwse82wx7Mfs6TJJUjOGjCSpGUNm5lo86AKm2Gw7XvCYZ4uhPmbnZCRJzTiSkSQ1Y8jMUEkeHXQNmhrD/lkn+WSSDw+6jqmQZF6SWwddx1QyZCRJzRgyM0CSdyf5XpKbk/xLkjnd9i8kWZ3kO0me9UtRM0GSjyW5Pcn1SS5I8uEk/9k91hVJ1iR5XZJvJvnfJKcMuuaJSvJXST7Qff+FJN/tvt8vyXk9/bZLcmOSAwZV62RJ8tEk/5PkemC3bttQf8495iQ5q/tv98okz0/ygSS3JbklyYWDLnAyGTLTXJLfB94FvLGq9gSeBo4GXgCsqKpXAdcAnxhclZMjyeuAPwNeA+wP9H5B7YnuF9bOBL4FvB/YA3hvkm2nutZJdh2wb/f93sBWSTbrtl0LkOR3gMuBj1fV5QOpcpIk2Qs4AtgTeAfwup7Vw/w5b7ArcHr33+5DdP7OLwLmV9WrgYWDLG6ybTroAvSs3gzsBSxPAvB84F7gGeCibp+vAd8cSHWT643At6pqPbA+yWU965Z2//whsLqq7gFIcgewE/DAlFY6uVYCeyXZBngcWEUnbPYFPgBsBnwHeH9VXTOwKifPvsClVfVLgCRLe9YN8+e8wU+q6ubu+5XAPOAW4LwkS4AlgyqsBUcy01+Af6uqPbuv3arqk6P0G/Zr0R/v/vlMz/sNyzP6l6WqehL4CfBe4L/ojGz+BNgFWAM8Rec/o7cPqMSpNLSfc4/e43qaznEdAJwOvJbOL5TDcqyGzAzwHeDQJC8BSPLiJC+j89kd2u1zFHD9gOqbTDcABybZIslWwJ8OuqApdB3wYTqnx66jc8rk+9X5IlsBxwKvTPKRwZU4aa4FDunORWwNHDjoggZsE2Cnqroa+AgwF9hqsCVNnqFJy2FVVbclORm4MskmwJN0zlM/BuzTXXcvnXmbGa2qlndPndwC/JzOKZPZ8ujS64CPAjdW1WNJ1nfbAKiqp5McCSxN8khV/fOgCp2oqlqV5CLgB3T+7i4fcEmDNgf4WpK5dM5cnFpVDw24pknjN/41rSTZqqoeTbIlnd94T6iqVYOuS9Jz40hG083iJLsDW9CZizJgpBnMkYwkqRkn/iVJzRgykqRmDBlJUjOGjDRAs+kOxJqdDBlJUjOGjDSFkhzTvdPuD5J8dcS645Ms7677Rve7QiQ5LMmt3fYNN8x8Vc+duW9Jsusgjkd6Nl7CLE2RJK8CLgXeUFX3J3kxnRtgPlpVn0+ybVU90O17CvDzqjotyQ+BBVV1d5IXVtVDSU4D/ruqzkuyOTCnqn41qGOTxuJIRpo6+wGXVNX9AFX14Ij1eyS5rhsqRwOv6rbfAJyb5Hg6tyABuBH4m+69zF5mwGi6MmSk6eNc4MSq+gPgU3TuekBVLQROpnOr+5XdEc/5wEHAr4BlSfYbTMnSxhky0tT5LnDYhodvdU+X9doauKf7wLKjNzQmeUVV3VRVHwfuA3ZK8nLgjqo6lc7DvV49JUcgjZP3LpOmSFWtTvJp4JokTwPfB+7s6fIx4CY6QXITndAB+IfuxH7oPPrhB3RuCf+eJE8CPwM+MyUHIY2TE/+SpGY8XSZJasaQkSQ1Y8hIkpoxZCRJzRgykqRmDBlJUjOGjCSpGUNGktTM/wPlywg9cu2opQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "d = np.diff(np.unique(truth)).min()\n", "left_of_first_bin = truth.min() - float(d)/2\n", @@ -125,7 +150,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -156,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -176,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -195,9 +220,43 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.77459445 0.11209536 0.06052605 0.00362192 0.07104932]\n", + " [0.11003839 0.80009249 0.01965666 0.02884268 0.01506941]\n", + " [0.09383029 0.07687496 0.65189991 0.07650355 0.12909478]\n", + " [0.12409745 0.0805176 0.00174395 0.75464134 0.06014014]\n", + " [0.06169615 0.08333079 0.04914953 0.02018145 0.77150318]]\n" + ] + }, + { + "data": { + "text/plain": [ + "Text(0, 0.5, 'true class')" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAARYAAAEMCAYAAAABAJmyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEr5JREFUeJzt3XuQXnV9x/H3ZzeJuRJu0REbCXIVKESzUdo0HS5W5KYooXLTUtRUTQ1tJ4qO3NrRqY7UTJlBS6gdqdyiFCGC3CYQciGWbALkwkUqgdqWEkzkkoQku8m3f5yT+mTZhfMkv3POPruf18zOc57znOf8PtlsPvmd8zz7HEUEZmYptdUdwMwGHheLmSXnYjGz5FwsZpaci8XMknOxmFlyLpYGkjbWnWEw6O/fZ0lXSppVd44iJE2QtLruHD25WMwsuUFbLJIukPSIpMckXSupPV8/W9IaSfMljas7J4CkyyQ9LWmxpJslzZK0IM/aKelJSZMl3SbpGUnfqDnvlyXNzJdnS3ogXz5R0o0N2+0vaamk0+rK2pDl65J+KWkxcHi+rt9+j3tol3Rd/nN7n6QRkmZKekLSSkm3VB1oUBaLpPcCnwSmRMREYDtwPjAK6IyIo4CHgCvqS5mRNBk4CzgWOAXoaHh4W0R0AP8E3AHMAI4GLpS0X9VZGywCpubLHcBoSUPzdQsBJL0DuAu4PCLuqiVlTtIk4BxgInAqMLnh4f76PW50KHBN/nP7MtnPy1eB90XEMcDnqw40pOoB+4mTgEnAMkkAI4B1wA5gbr7NDcBttaTb1RTgjojYAmyR9LOGx+blt6uANRHxAoCkZ4HxwPpKk/7OcmCSpL2ArcAKsoKZCswEhgLzgRkR8VBNGRtNBX4aEZsBJM1reKy/fo8brY2Ix/Ll5cAEYCVwo6TbgdurDjQoZyyAgOsjYmL+dXhEXNnLdv39F6m25rc7GpZ33q/tP42I6ALWAhcCD5PNYE4ADgGeBLrJ/gGcXFPEZvTL73EPjbm2k+U6DbgGeD/Zf6CVZh2sxTIfmCbp7QCS9pV0INn3Y1q+zXnA4pryNVoCnCFpuKTRwOl1BypoETCL7NBnEdl0/NHIfus1gIuAIyRdUl/E/7cQODM/NzEGOKPuQHuoDRgfEQ8ClwBjgdFVBugvjVupiHhC0qXAfZLagC6yY+dNwAfyx9aRnYepVUQsy6fmK4EXyabkr9SbqpBFwNeBpRGxSdKWfB0AEbFd0rnAPEmvRcT36goaESskzQUeJ/t7X1ZXlkTagRskjSWbnV8dES9XGUD+2IT+T9LoiNgoaSTZ/67TI2JF3bnM+jIoZywtaI6kI4HhZOeGXCrWr3nGYmbJDdaTt2ZWIheLmSXnYjGz5FwsDSRNrztDM1otL7Re5lbLC/0js4tlV7X/hTSp1fJC62VutbzQDzK7WMwsuZZ7uXm/fdti/Phy3n6zfv0O9tsvfdeufWb/5PsE2Na9mWFDRqbfcXd3+n3mtu14nWFtI5LvN7rKydzFVobytlL2rWFDS9nvtu2vM6w9/ff49e5X2bb9dRXZtuXeIDd+/BDu+3k5/1DL8qnTP1t3hKa0rftt3RGa1v3C/9YdoWlDDhhfd4SmPPw/N771RjkfCplZci4WM0vOxWJmyblYzCw5F4uZJediMbPkXCxmlpyLxcySc7GYWXIuFjNLzsViZsm5WMwsOReLmSXnYjGz5FwsZpaci8XMknOxmFlytRaLpI11jm9m5fCMxcySq6xYJF0g6RFJj0m6VlJ7vn62pDWS5ksaV1UeMytPJcUi6b3AJ4EpETER2A6cD4wCOiPiKOAh4Io+nj9dUqekzvXrd1QR2cz2QFWf0n8SMAlYJglgBLAO2AHMzbe5AbittydHxBxgDsDEY4e11vVKzAahqopFwPUR8bVdVkqX9djOpWE2AFR1jmU+ME3S2wEk7SvpwHz8afk25wGLK8pjZiWqZMYSEU9IuhS4T1Ib0AXMADYBH8gfW0d2HsbMWlxlV0KMiLn87nzKTqOrGt/MquP3sZhZci4WM0vOxWJmyblYzCw5F4uZJediMbPkXCxmlpyLxcySc7GYWXIuFjNLzsViZsm5WMwsOReLmSXnYjGz5FwsZpaci8XMklNEa33M7NiRB8Rxh32m7hhNufueW+
qO0JSPHPTBuiM0TUMq+8yyZGJbV90RmvKLrnt4dcd6FdnWMxYzS87FYmbJuVjMLDkXi5kl52Ixs+RcLGaWnIvFzJJzsZhZci4WM0vOxWJmyblYzCw5F4uZJediMbPkXCxmlpyLxcySc7GYWXIuFjNLzsViZsm5WMwsOReLmSVX6icQS7oMuAB4Cfg1sBw4HXgUmAqMAj4NfA34fWBuRFxaZiYzK19pxSJpMnAWcCwwFFhBViwA2yKiQ9LFwB3AJGAD8CtJsyNifVm5zKx8ZR4KTQHuiIgtEfEa8LOGx+blt6uANRHxQkRsBZ4FxvfckaTpkjoldW7r3lRiZDNLoa5zLFvz2x0Nyzvvv2EWFRFzIqIjIjqGDRlVRT4z2wNlFssS4AxJwyWNJju3YmaDQGnnWCJimaR5wErgRbLDnlfKGs/M+o+yr0t5VURcKWkksBBYHhHX7XwwIhYACxruH19yHjOrQNnFMkfSkcBw4PqIWFHyeGbWD5RaLBFxXpn7N7P+ye+8NbPkXCxmlpyLxcySc7GYWXIuFjNLzsViZsm5WMwsOReLmSXnYjGz5FwsZpaci8XMknOxmFlyLhYzS87FYmbJNVUsktok7VVWGDMbGN7y81gk3QR8HtgOLAP2kvSPEfGdssP1KgJt7a5l6N112pSP1R2hKX+87Km6IzRtyUcPrztC07aPG1t3hOasXlB40yIzliMj4lXgTOBu4CDgU7sVzMwGhSLFMlTSULJimRcRXUCUG8vMWlmRYrkWeI7scqgLJR0IvFpmKDNrbW95jiUirgaublj1vKQTyotkZq3uLWcski6WtJcyP5C0Ajixgmxm1qKKHApdlJ+8/TCwD9mJ22+VmsrMWlqRYlF+eyrwo4hY07DOzOwNihTLckn3kRXLvZLGkF283cysV0UuWPYZYCLwbERslrQf8OflxjKzVlbkVaEdktYCh0kaXkEmM2txRd7S/1ngYuD3gMeA44Cl+JUhM+tDkXMsFwOTgecj4gTgfcDLpaYys5ZWpFi2RMQWAElvi4ingNb7jS8zq0yRk7f/JWlv4Hbgfkm/BZ4vN5aZtbIiJ28/ni9eKelBYCxwT6mpzKyl9VkskvbtZfWq/HY0sKGURGbW8t5sxrKc7OMRGt9lu/N+AO8pMZeZtbA+iyUiDqoyiJkNHEV+u/njksY23N9b0pmpAkjamGpfZtY/FHm5+YqIeGXnnYh4GbiivEhm1uqKFEtv2xR5mRoASV+WNDNfni3pgXz5REk3Nmy3v6Slkk4rum8z65+KFEunpO9KOjj/+i7Zid2iFgFT8+UOYHT+GbpTgYUAkt4B3AVcHhF39dyBpOmSOiV1buve3MTQZlaHIsXyJWAbMBe4BdgCzGhijOXApPx6RFvJfs+og6xYFgFDgfnAVyLi/t52EBFzIqIjIjqGDRnZxNBmVocib5DbBHx1dweIiK78t6MvBB4GVgInAIcATwLdZOVzMvDQ7o5jZv1HVZdYXQTMIjv0WUR2AbRHIyLI3hNzEXCEpEsqymNmJaqyWN4JLI2IF8kOpxbtfDAitgPnAidK+mJFmcysJIVf3dkTETGf7FzKzvuHNSyPzm+3kh0OmVmLK/IGucMkzZe0Or9/jKRLy49mZq2qyKHQdcDXgC6AiFgJnFNmKDNrbUWKZWREPNJjXXcZYcxsYChSLL+RdDD5heAlTQNeKDWVmbW0IidvZwBzyF4O/m9gLXBBqanMrKUVeYPcs8CHJI0C2iLitfJjmVkrK3L5j8t73AcgIv6upExm1uKKHAptalgeDpxO9lZ8M7NeFTkU+ofG+5KuAu4tLZGZtbzdeUv/SLKrIpqZ9arIOZZV5C81A+3AOMDnV8ysT0XOsZzesNwNvBgRfoOcmfXpTYtFUjtwb0QcUVEeMxsA3vQcS/5xBk9LendFecxsAChyKLQPsEbSIzS89BwRHy0tlZm1tCLFclnpKcxsQClSLKdGxC4fGSnp29T0+bTRJnaMGV7H0LutfUNrnet+6JgRdUdo2veev6nuCE2bedy0uiM0RV1dhbct8j6WP+ll3SmFRzCzQafPGYukLwBfBN4jaWXDQ2OAJWUHM7PW9WaHQjcBdwN/z66X/3gtIjaUmsrMWlqfxZJfr/kVsk/PNzMrrKrLf5jZIOJiMbPkXCxmlpyLxcySc7GYWXIuFjNLzsViZsm5WMwsOReLmSXnYjGz5FwsZpaci8XMknOxmFlyLhYzS66yYpF0paRZVY1nZvXxjMXMkiu1WCR9XdIvJS0GDs/XLZA0W1KnpCclTZZ0m6RnJH2jzDxmVo0in9K/WyRNAs4BJubjrACW5w9vi4gOSRcDdwCTgA3AryTNjoj1ZeUys/KVOWOZCvw0IjZHxKvAvIbHdi6vAtZExAsRsRV4Fhjfc0eSpucznM6u7s0lRjazFOo6x7I1v93RsLzz/htmURExJyI6IqJj6JCRVeQzsz1QZrEsBM6UNELSGOCMEscys36ktHMsEbFC0lzgcWAdsKysscysfymtWAAi4pvAN3usvqrh8QXAgob7x5eZx8yq4fexmFlyLhYzS87FYmbJuVjMLDkXi5kl52Ixs+RcLGaWnIvFzJJzsZhZci4WM0vOxWJmyblYzCw5F4uZJediMbPkXCxmlpyLxcySc7GYWXKlfoJcGdS1nfYXNtQdoykx4m11R2hK+7hxdUdo2ozDTqo7QtP+9T/+re4ITTn51JcLb+sZi5kl52Ixs+RcLGaWnIvFzJJzsZhZci4WM0vOxWJmyblYzCw5F4uZJediMbPkXCxmlpyLxcySc7GYWXIuFjNLzsViZsm5WMwsOReLmSXnYjGz5CorFkkTJK2uajwzq49nLGaWXNXF0i7pOklrJN0naYSkmZKekLRS0i0V5zGzElT9Kf2HAudGxOck/Rg4C/gqcFBEbJW0d29PkjQdmA4wvH1MZWHNbPdUPWNZGxGP5cvLgQnASuBGSRcA3b09KSLmRERHRHQMaxtRTVIz221VF8vWhuXtZDOm04BrgPcDyyS13LWOzGxXdZ+8bQPGR8SDwCXAWGB0vZHMbE/VPTtoB26QNBYQcHVEFL/cmpn1S5UVS0Q8BxzdcP+qqsY2s2rVfShkZgOQi8XMknOxmFlyLhYzS87FYmbJuVjMLDkXi5kl52Ixs+RcLGaWnIvFzJJzsZhZci4WM0vOxWJmyblYzCw5F4uZJediMbPkXCxmlpwiou4MTZH0EvB8SbvfH/hNSfsuQ6vlhdbL3Gp5obzMB0bEuCIbtlyxlElSZ0R01J2jqFbLC62XudXyQv/I7EMhM0vOxWJmyblYdjWn7gBNqjSvpI357QGSbn2Lbf9K0sheHuozs6TjJd3ZRJ4Fksqe8rfazwT0g8w+xzLISWqPiO0Ft90YEYUuKCfpOaAjIgqfRJR0PDArIk4vuP2CfPvOomNYNTxjGaAkTZD0lKQbJT0p6dadMwhJz0n6tqQVwNmSDpZ0j6TlkhZJOiLf7iBJSyWtkvSNHvtenS+3S7pK0mpJKyV9SdJM4ADgQUkP5tt9ON/XCkk/kTQ6X/+RP
OcK4BN9/FneMEYv23xfUqekNZL+tmH9tyQ9kT/vqnzd2fm+Hpe0MM133HYREf4agF/ABCCAKfn9fyH73x3gOeArDdvOBw7Nlz8IPJAvzwM+nS/PADY27Ht1vvwF4FZgSH5/34Yx9s+X9wcWAqPy+5cAlwPDgV8Dh5JdCfPHwJ29/Fn6GmMB2ayocV17vv4YYD/gaX43M987v10FvKtxnb/SfnnGMrD9OiKW5Ms3AH/U8NhcgHzm8IfATyQ9BlwLvDPfZgpwc778oz7G+BBwbUR0A0TEhl62OQ44EliSj/FnwIHAEcDaiHgmsn/lN+zBGH+az3oeBY7Kx3sF2AL8QNIngM35tkuAH0r6HFkRWWJ1X7vZytXzBFrj/U35bRvwckRMLLiP3SHg/og4d5eVUl9jNrdz6SBgFjA5In4r6YfA8IjolvQB4CRgGvCXwIkR8XlJHwROA5ZLmhQR61NksYxnLAPbuyX9Qb58HrC45wYR8SqwVtLZAMocmz+8BDgnXz6/jzHuB/5C0pD8+fvm618DxuTLvwCmSDok32aUpMOAp4AJkg7Ot9uleAqMsdNeZEX5iqR3AKfk240GxkbEz4G/Bo7N1x8cEf8eEZcDLwHj+xjXdpOLZWB7Gpgh6UlgH+D7fWx3PvAZSY8Da4CP5esvzp+/CnhXH8/9Z+A/gZX588/L188B7pH0YES8BFwI3CxpJbAUOCIitgDTgbvyw5h1TY4BQEQ8TnYI9BRwE1khQlZsd+ZjLgb+Jl//nfyE9GrgYeDxPsa13eSXmwcoSRPIToQeXXMUG4Q8YzGz5DxjMbPkPGMxs+RcLGaWnIvFzJJzsZhZci4WM0vOxWJmyf0fFeFaBj1BNuAAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "cm = np.eye(M_classes) + 0.2 * np.random.uniform(size=(M_classes, M_classes))\n", "cm /= np.sum(cm, axis=1)\n", @@ -219,9 +278,23 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.11284647 0.80101817 0.03185446 0.03936846 0.01491243]\n", + " [0.09973749 0.07725771 0.60755048 0.08775271 0.12770162]\n", + " [0.08582221 0.08195349 0.05213019 0.02797892 0.75211519]\n", + " ...\n", + " [0.07873401 0.08532457 0.05993838 0.03413578 0.74186726]\n", + " [0.05587737 0.46613713 0.36092296 0.04309275 0.0739698 ]\n", + " [0.09950859 0.07750803 0.61382232 0.07556721 0.13359384]]\n" + ] + } + ], "source": [ "C = proclam.classifiers.from_cm.FromCM(seed=None)\n", "predictionC = C.classify(cm, truth, other=False)\n", @@ -238,7 +311,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -282,7 +355,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -293,7 +366,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -344,9 +417,21 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.15842418 0.18857113 0.16013153 0.17989965 0.01515237]\n", + " [0.17176994 0.13661988 0.13941695 0.09105078 0.10789923]\n", + " [0.17019976 0.06886277 0.77397623 0.01511753 0.06713813]\n", + " [0.04351717 0.12580389 0.14106496 0.06285216 0.11175521]\n", + " [0.06891721 0.0174588 0.12293988 0.07999228 0.12880625]]\n" + ] + } + ], "source": [ "# N = 3 #len(truth)\n", "# M = len(cm)\n", @@ -372,9 +457,32 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAARMAAAD0CAYAAAC4n8I2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGuBJREFUeJzt3X2UXVWd5vHvkxAISQDBqK2AktYgIo0oAe1hmAZ8i6LgC3YDOtOOthlHELpZOOISo8PYq1ubkWnXopXYutr2DdS2NS3R6CDIi4h5EQIJojGoJM0IAUFezFvVM3+cU/FWUVX3VO65de+pej5rnVXnnLvr7B83yY999tlnb9kmIqJTM3odQERMDUkmEVGLJJOIqEWSSUTUIskkImqRZBIRtUgyiYhaJJlERC2STCKmIUmLJd0laaOki0b5/DJJt5bbTyU91PaaGQEb0f9eefJcP/DgQKWya9ZtX2l78VifS5oJ/BR4ObAZWAWcZXvDGOXfDbzQ9tvGq3evStFFRE9tfXCAW1YeUqnsrKf/fH6bIscDG21vApB0JXA6MGoyAc4CPtiu3iSTiEYwAx6sWni+pNUtx8tsL2s5Phi4p+V4M/Di0S4k6VnAAuB77SpNMoloAAODVO6S2Gp7UU1Vnwl81Xbbe6wkk4iGGKRyy6SdLcChLceHlOdGcyZwTpWLJplENIAxA/U9LFkFLJS0gCKJnAmcPbKQpCOAA4Gbq1x02j4alvRor2Pod738jiR9SNKFvaq/JY7DJN3R6ziguM2psrVjexdwLrASuBP4su31ki6RdFpL0TOBK13xkW9aJhENYGCgep9J++vZK4AVI84tHXH8oYlcc1q0TCS9RdKPygE4V5TP2YcG5qyXdI2kp0xCHB8oBwrdKOlLki6UdF0Zx2pJd0o6TtLXJP1M0oe7HM97JJ1X7l8m6Xvl/imSvtBSbr6kmyWd2uV43l8OkLoReG55rmffT4uZkj5V/l35jqR9JZ0naYOkdeWj1a4ysNODlbZemfLJRNLzgD8DTrB9DDAAvBmYC6y2/Xzg+1R4jt5hHMcBbwReALwKaO1t31H2vn8S+AZFh9dRwFslPbmLYd0AnFjuLwLmSZpVnru+jPtpwNXAUttXdysQScdSNKuPAV4NHNfyca++nyELgcvLvysPUfw5XkQxkOto4J2TEAODFbdemQ63OS8FjgVWSQLYF7iP4nu/qizzeeBrXY7jBOAbtrcB2yT9W8tny8uftwPrbd8LIGkTRa/7A12KaQ1wrKT9ge3AWoqkciJwHjALuAY4x/b3uxTDkBOBf7X9OICk5S2f9er7GXK37VvL/TXAYcA64AuSvg58vcv1Fx2wNd7mdMOUb5kAAj5r+5hye+4Y94K9/JPaXv4cbNkfOu5awre9E7gbeCvwA4qWysnAcyg65nZR/ON5ZbdiqKgn388o9UPRst0LOBW4HHgRxf+ouhuHYaDi1ivTIZlcA5wh6akAkg4qR/XNAM4oy5wN3NjlOG4CXitptqR5wGu6XF9VNwAXUtzW3EDRZP9x2YNv4G3AEZLe2+U4rgdeV/ZH7Ae8tsv1dWIGcKjta4H3AgcA87pZYTFoLbc5PWV7g6SLge9ImgHspLjnfgw4vvzsPop+lW7Gsapsuq8Dfk3RZH+4m3VWdAPwfuBm249J2laeA8D2gKSzgOWSHrH9D90IwvZaSVcBt1H8eazqRj01mQl8XtIBFC3fj9tu+1ZtZ8QA6m4VHcpbw5NI0jzbj0qaQ/F/4iW21/Y6ruh/Rx29t//l6nbv7xWOeOa9a2ocTl/ZlG+Z9Jllko4EZlP04ySRRCUGdvR5r0SSySSy/YQhyxFVDbq/b3OSTCIaoBgBm2QSER0yYqDPb3P6O7oukrSk1zG06qd4+ikWSDxDBq1KW69M22QC9NVfUPornn6KBRLP7tucKluv5DYnohHEgPv7//2NSyZ7zZ7rfeYd1PF19p57IHPnH9rxIJvBmR2HAsCs/Q5k36d1Hs+MahOYj6uu76auv/uz5h3InKd2Hs9e9z9WRzjMZg7766CO43mE32y1XeltdQM7qekvW5c0LpnsM+8gnvfav+p1GLttf1J/9bDv83D/DELcOae/vpunfvKWXocwzP8duOqXVcvaaZlERE0G82g4IjpVdMCmZRIRHcttTkTUoJiCIMkkIjpkxA7naU5E1GAwtzkR0akmdMD2d3QRAZQv+rnaVoWkxeWyKxslXTRGmT8tl/NYL+mL7a6ZlklEQ9TVAVuuG3U58HJgM8WE2Mttb2gpsxB4H8USMb8ZmkN5PEkmEQ1gU+ej4eOBjbY3AZSLiJ0ObGgp8w6KtYJ+U9Tv+9pdNLc5EY0gBituwPxyBcShbeRbzgcD97Qcby7PtTocOFzSTZJ+KGlxuwjTMoloAAM7XPmf69YaJpTei2Ilw5OAQ4DrJf3ReLPwJ5lENICpdeKjLRQrIQ45pDzXajNwy9BCbZJ+SpFcxlyCpKe3OZIe7WX9EU0ywIxKWwWrgIWSFkjam2KN5+UjynydolWCpPkUtz2bxrtoWiYRDWDqG7Rme5ekc4GVFAuKfcb2ekmXAKttLy8/e4WkDRRLor7H9rhrOk9aMpH0ForFsPcGbgHeVZ6/DHgF8P+AM23fP1kxRTRHvVMy2l4BrBhxbmnLvoELyq2SSbnNkfQ8iuU3T7B9DEWmezMwlyITPh/4PvDBMX5/yVDP9K5t9cyWFdEkQy2TKluvTFbL5KXAsRSDYwD2pVhPdhC4qizzeeBro/2y7WXAMqCW6QQjmijr5hREsRzm+4adlD4wolwSRcQobLFzsL+7OCerTXQNcMbQkFxJB0l6Vln/GWWZs4EbJymeiEYp5jOpPGitJyYl1dneIOli4DuSZgA7gXOAx4Djy8/uo+hXiYgnyExru9m+it/3jwyZN1n1RzRZ0QGbPpOIqEG/z2eSZBLRADUPp++KJJOIhsiE0hHRMRt2DiaZRESHitucJJOIqEFGwEZEx/JoOCJqktuciKhJL4fKV5FkEtEAxez0SSYR0SEjdg1mreGIqEFuc2rmmbDjgP75UufcN9jrEIbppz66+et+1+sQhtnxihf1OoThvjXyvdex5WlORNQmT3MionPOi34RUYOhmdb6WZJJREP0e8ukv2/CIgIoWia7BmdU2qqQtFjSXZI2SrpolM/fKul+SbeW21+0u2ZaJhENUOfkSJJmApcDL6dYU3iVpOW2N4woepXtc6teNy2TiIaocXb644GNtjfZ3gFcCZzeaXxJJhFN4KLPpMpWwcHAPS3Hm8tzI71R0jpJX5V0aLuLJplENMDQoLWKyWT+0HK65bZkD6r8N+Aw20cD3wU+2+4X0mcS0RAT6DPZanvROJ9vAVpbGoeU53az/UDL4T8CH21XaZJJRAMYMVDfHLCrgIWSFlAkkTMpVtTcTdLTbd9bHp4G3NnuokkmEQ1R16A127sknQusBGYCn7G9XtIlwGrby4HzJJ0G7AIeBN7a7rpJJhENYNc7aM32CmDFiHNLW/bfB7xvItdMMo
loCPf5CNgkk4hGyIt+EVGTad0ykfQB4C3A/RSDZNYArwF+DJwIzAX+C8W92R9RDN+9uJsxRTTRtJ4cSdJxwBuBFwCzgLUUyQRgh+1Fks4HvgEcS9Fj/HNJl414xk056GYJwKz9DuxWyBH9qwETSndzBOwJwDdsb7P9CMWIuiHLy5+3A+tt32t7O7CJ4YNpALC9zPYi24v22nduF0OO6E+muM2psvVKr/pMtpc/B1v2h47TjxPxBP3fAdvNlslNwGslzZY0j6KvJCL2kF1t65WutQJsr5K0HFgH/JrilubhbtUXMdVN66c5wKW2PyRpDnA9sMb2p4Y+tH0dcF3L8UldjieikYpWx/ROJsskHQnMBj5re22X64uYsvq9z6SrycT22e1LRUQVg4PTOJlERD1Mbx/7VpFkEtEQPXxQU0mSSUQTpAM2ImrT502TJJOIhphSLRNJM4B5tn/bpXgiYgy9HN1aRdvh9JK+KGl/SXOBO4ANkt7T/dAiYogNHpxRaeuVKjUfWbZEXgd8C1gA/OeuRhURT9Dv7+ZUSSazJM2iSCbLbe+k77uCIqYgV9x6pEoyuQL4BcWsaNdLehaQPpOISVVtLpO+ns/E9seBj7ec+qWkk7sX0vhm/s48ef329gUnyewNW9oXmkRXr/l2r0PY7ZXPOKbXIQwz5w+e1usQOtPn9wNVOmDPLztgJenTktYCp0xCbBExxPXOtCZpsaS7JG2UdNE45d4oyZLGW24UqHab87ayA/YVwIEUna9/WyniiKhPTX0mkmYClwOvAo4Ezirf7h9Zbj/gfOCWKuFVSSZDqe7VwOdsr285FxGTxaq2tXc8sNH2Jts7gCuB00cp97+AjwDbqly0SjJZI+k7FMlkZZmtBqtcPCJqVN/TnIMplp4Zsrk8t5ukFwGH2r66anhVRsC+HTgG2GT7cUlPBv5r1QoiogamaqsDYL6k1S3Hy2wvq/rL5Uj3j1FhsfJWVZ7mDEq6Gzhc0uyJXDwi6jOBAWlbbY/XYbqF4UvKHFKeG7IfcBRwnSSAPwCWSzrNdmuSGqZtMpH0FxSdMIcAtwIvAW4mT3QiJld9j4ZXAQslLaBIImcCu2dFtP0wMH/oWNJ1wIXjJRKo1mdyPnAc8EvbJwMvBB6aaPQR0aGaOmBt7wLOBVYCdwJftr1e0iWSTtvT8Kr0mWyzvU0Skvax/RNJz93TCiNiDxhU42MP2yuAFSPOLR2j7ElVrlklmWyW9CTg68B3Jf0G+GWVi0dEXSo/9u2ZKh2wry93PyTpWuAAoH/GbEdMF30+nH7MZCLpoFFO317+nAc82JWIImJ0TU0mwBqK8FvbVkPHBv6wi3FFxEhNTSa2F0xGAJIetT1vMuqKaKyJDVrriSpvDb9e0gEtx0+S9LruhhURI8nVtl6pMs7kg+UgFgBsPwR8sGoFkt4j6bxy/zJJ3yv3T5H0hZZy8yXdLOnU6uFHTCNTYKa10cpMZFb7G4ATy/1FwLxyGsgTgesBJD0NuBpYOpEXiyKmk6nQMlkt6WOSnl1uH6PonK1qDXCspP2B7RRD8RdRJJMbgFnANcD/sP3d0S4gaYmk1ZJW79z52ASqjphC6puCoCuqJJN3AzuAqyjmPdgGnFO1gnIC6rsp3kD8AUUCORl4DsVQ3l0UCeeV41xjme1FthfNmjW3atURU0fVW5wetkyqDFp7DBhzWreKbgAuBN5GMVblY8Aa25bk8vxXJL3X9kc6rCtiaurzR8OTtWLPDcDTgZtt/5qidXPD0Ie2B4CzgFMkvWuSYopolH7vM5mUtYZtX0PRNzJ0fHjL/rzy53bGudWJmPb6vGWShcsjGkA1vzXcDVUGrR0u6RpJd5THR0u6uPuhRcQwU+BpzqeA9wE7AWyvo5iZKSImU9Of5gBzbP+onAtyyK4uxRMRY+hl52oVVZLJVknPpsx5ks4A7u1qVBHxRFMgmZwDLAOOkLSFYgDaW7oaVUQM1+PHvlVUGbS2CXiZpLnADNuPdD+siHiCpicTSUtHHANg+5IuxRQRo+j3R8NVbnNa36ybDbyG4p2aiIjdqtzm/O/WY0mXUqy3ERGTqem3OaOYQ7G6X0RMlgZ0wFYZAXu7pHXlth64C/g/3Q8tIoapcdCapMWS7pK0UdITZgWQ9M7y3/6tkm6UdGS7a1ZpmbymZX8X8OtyecGemLFjgH3u6Z/VSf/9Df01Sf9Rf98/L10f/JJHex3CMPc9r8/mwvnMBMvX1DKRNBO4HHg5sBlYJWm57Q0txb5o+5Nl+dMopg1ZPN51x00mZaUrbR/RSfAR0RlR623O8cDGctgHkq4ETgd2JxPbv20pP5cKqWzcZGJ7oGwKPdP2r/Yo7Ijo3MTeGp4vaXXL8TLby1qODwbuaTneDLx45EUknQNcAOwNnNKu0iq3OQcC6yX9iJbHxLb3eLX0iNgD1VsmW20v6rg6+3LgcklnAxcDfz5e+SrJ5AOdBhURNajvNmcLcGjL8SHlubFcCXyi3UWrTEHwatvfb92AV1f4vYioUY3TNq4CFkpaIGlviilFlg+rS1rYcngq8LN2F62STF4+yrlXVfi9iKhTTY+Gy6ex51IMPr0T+LLt9ZIuKZ/cAJwrab2kWyn6Tca9xYFxbnMk/XfgXcAfSlrX8tF+wE3tQ46I2tQ88ZHtFcCKEeeWtuyfP9Frjtdn8kXgW8DfMHypi0dsPzjRiiKiM4190a9cX/hhiiUoIqLH+n04fWanj2iKJJOI6FiPJ4uuIskkogFUbv0sySSiKdIyiYg6pAO2JOlDwKO2L52sOiOmlKY+Go6IPjIVZlrrhKT3S/qppBuB55bnrpN0maTVku6UdJykr0n6maQPdzOeiEabAsuD7hFJx1K8QHRMWc9aYE358Q7biySdD3wDOBZ4EPi5pMtsPzDiWkuAJQCz99q/WyFH9LXp3DI5EfhX24+Xsza1vpU4tH87sN72vba3A5sY/mo0ALaX2V5ke9HeM+d0MeSIPjZdWyZtbC9/DrbsDx2nHydiFNO5ZXI98DpJ+0raD3htF+uKmNqqtkqmYsvE9lpJVwG3AfdRTMgSEXtANPit4TrY/mvgr0ecvrTl8+uA61qOT+pmPBGN1ue3OemfiGgIub+zSZJJRBPkreGIqEu/P81JMoloiiSTiKhDWiYR0bmJLQ/aE0kmEU3R5y2Trr41HBH1ELWu6IekxZLukrRR0kWjfH6BpA2S1km6RtKz2l0zySSiKexqWxuSZgKXU6zMeSRwlqQjRxT7MbDI9tHAV4GPtrtukklEQ9TYMjke2Gh7k+0dFAuTn95awPa1th8vD39Isbj5uJJMIppgYi/6zS8nHxvaloy42sHAPS3Hm8tzY3k7xeqe40oHbERDTOBpzlbbi2qpU3oLsAj4k3ZlG5dMBveZye8WHNjrMHZ7+ufu6HUIw/zqXUf1OoTdHn/Gvr0OYZjHDu73lWfGV+Oj4S0Mn4TskPLc8PqklwHvB/6knLxsXLnNiWgCU1sHLMV0IAslLZC0N8X0qq0zISLphcAVwGm276ty0ca1TCKmq7pGwNreJelcY
CUwE/iM7fWSLgFW214O/B0wD/iKJIBf2T5tvOsmmUQ0RY2D1myvAFaMOLe0Zf9lE71mkklEAwwNWutnSSYRTVC9P6RnkkwiGiIv+kVELXKbExGdMzDY39kkySSiKfo7lySZRDRFbnMioh55mhMRdUjLJCI6JoPSARsRtejzcSaT9tawpMMk9df7+hENIrvS1itpmUQ0QQOWB53s+UxmSvqUpPWSviNpX0nntcyCfeUkxxPREBXnMplGLZOFwFm23yHpy8AbgYuABba3S3rSaL9UzmG5BGCf2aMWiZjy+v1pzmS3TO62fWu5vwY4DFgHfKGca3LXaL9ke5ntRbYXzdp77uREGtFv+rxlMtnJpHUeyQGKltGpFGt4vAhYJSn9OBEjGTTgSluv9HoO2BnAobavBd4LHEAxVVxEjFR9qYue6HUrYCbweUkHUEwm9XHbD/U4poi+1MvHvlVMWjKx/QvgqJbjSyer7ogpIckkIjpm+n4EbJJJRAOI3o5urSLJJKIp+jyZ9PppTkRUYWDA1bYKJC2WdJekjZIuGuXz/yRpraRdks6ocs0kk4iGqOtFP0kzKcZ2vQo4EjhL0pEjiv0KeCvwxarx5TYnoinqu805HthoexNA+U7c6cCG31flX5SfVe72TcskohFqfdHvYOCeluPN5bmOpGUS0QRmIi2T+ZJWtxwvs72s/qCGSzKJaIrq40y22l40zudbgENbjg8pz3UktzkRDVHjTGurgIWSFkjaGzgTWN5pfEkmEU1gYGCw2tbuUvYu4FxgJXAn8GXb6yVdIuk0AEnHSdoMvAm4QtL6dtfNbU5EI9Q7V4ntFcCKEeeWtuyvorj9qaxxyeTR327Zev23L/plDZeaD2yt4Tp1qSeev+k8EKbqd/MvnQdSquv7edaESvf5CNjGJRPbT6njOpJWt+mkmlT9FE8/xQKJZ7ckk4jomIEswhURnTO4v+cgmM7JpOuDeCaon+Lpp1gg8fz+aU4fm7aPhidjROBE1BmPpEfLn8+Q9NU2Zf9S0pyJxCLpJEnfnEA810na4z6GqfxnNcGKMzt9dK5803NCbP+77Xavj/8lMKdNmegHSSYxnnIN5p9I+oKkOyV9dailIOkXkj4iaS3wJknPlvRtSWsk3SDpiLLcAkk3S7pd0odHXPuOcn+mpEsl3VGunvhuSecBzwCulXRtWe4V5bXWSvqKpHnl+cVlnGuBN4zx3/KEOkYp8wlJq8tVHf9ny/m/bVnZ8dLy3JvKa90m6fp6vvGmyop+Uc1zgbfbvknSZ4B3AUMTbj9g+0UAkq4B3mn7Z5JeDPwDcArw98AnbP+zpHPGqGMJxaJnx9jeJekg2w9KugA42fZWSfOBi4GX2X5M0nuBCyR9FPhUWddG4KqqdYxS5v1lvTOBayQdTfFeyOuBI2y7ZWXHpcArbW8Za7XHacPAYPpMor17bN9U7n8e+I8tn10FULYQ/gPwFUm3AlcATy/LnAB8qdz/3Bh1vAy4ohxKje0HRynzEorJcm4q6/hzioFVR1Csxvgz2y5j3NM6/rRs3fwYeH5Z38PANuDTkt4APF6WvQn4J0nvoFgWZXpLyyQqGPk3oPX4sfLnDOAh28dUvMaeEPBd22cNOymNVefELi4tAC4EjrP9G0n/BMwuWzHHAy8FzqB4b+QU2+8sW2CnAmskHWv7gTpiaaQ+H7SWlkl/eKakPy73zwZuHFnA9m+BuyW9CUCFF5Qf30Tx5ifAm8eo47vAfxtafrXlFuQRYL9y/4fACZKeU5aZK+lw4CfAYZKeXZYblmwq1DFkf4rk+LCkp1FMGzjU6jqgfF/kr4AXlOefbfuW8p2R+xn+2vz0YuOBgUpbrySZ9Ie7gHMk3QkcCHxijHJvBt4u6TZgPcVUewDnl79/O2PPmPWPFPN6rit//+zy/DLg25KutX0/xbyfX5K0DriZoh9jG0V/yNXlLcp9E6wDANu3Udze/IRibtGhW7v9gG+Wdd4IXFCe/7uyU/kO4AfAbWPUOz0MutrWI3KfN52mOkmHAd+0fVSbojGNHbDXU/zH+53eviCw8qFPr+nFu0PpM4loArvvn+YkmfTYyDWYI8bU53cRSSYRDeG0TCKic70dQ1JFkklEExjo4WPfKpJMIhrAgDM5UkR0zJkcKSJq0u8tkwxai2gASd+mmBW/iq22F3czntEkmURELfJuTkTUIskkImqRZBIRtUgyiYhaJJlERC2STCKiFkkmEVGLJJOIqEWSSUTU4v8DXq5Ucvg2FXoAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "plt.matshow(CMtunnel)\n", "plt.xticks(range(max(truth)+1), names)\n", @@ -395,7 +503,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -420,9 +528,32 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAARMAAAD0CAYAAAC4n8I2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGgtJREFUeJzt3X+YXmV95/H3h5AYSCBCoxQBJSI/iggoA+pStoKoKApaoQvIblmpKSsUul644KpoWXu1WlZ2vS5qiT9at4qg1koUSmApSEDU/BACCSAxSAlrxYAgYPNr5rN/nDPxmcnMPGfynOfXzOd1Xeeac85zP+d8M5P5zn3f537uW7aJiGjVTt0OICKmhiSTiKhFkklE1CLJJCJqkWQSEbVIMomIWiSZREQtkkwiohZJJhFRi527HUBENPeW4+f4yacGK5VdsWrTEtsntTmk7SSZRPSBDU8N8oMl+1YqO3Pvn8xvczhjSjKJ6Atm0EPdDmJCSSYRfcDAEL39odwkk4g+MURqJhHRImMGe3y6kGn7aFjSc92Oodd183sk6eOSLu7W/Rvi2F/S/d2OA4pmTpWtCkknSXpI0lpJl45T5g8krZG0WtI1za6ZmklEHzAwWFOfiaQZwFXAm4D1wDJJi22vaShzIPAh4Fjbv5T04mbXnRY1E0lnS/qhpHskXV1+M5F0ZZl1b5X0og7E8dHyr8Gdkr4q6WJJt5dxLJf0gKSjJX1T0sOSPtHmeD4o6cJy/0pJ/1zunyDpKw3l5ku6W9LJbY7nw5J+LOlO4ODyXNe+Pw1mSPpc+X/lZkm7SLqw/Ku9StK17Q7AwBYPVdoqOAZYa3ud7c3AtcCpo8q8D7jK9i8BbD/R7KJTPplI+h3gP1Bk2COBQeA9wBxgue1XAt8FPtbmOI4G3g0cAbwVGGh4ebPtAeBvgOuB84HDgHMk/VYbw1oKHFfuDwBzJc0sz91Rxr0XcANwme0b2hWIpKOAM4AjgbcBRze83K3vz7ADKX6xXgk8TfFzvBR4te3DgfM6EANDFTdgfpl8h7eFoy61D/BYw/H68lyjg4CDJN0l6fuSmg6Cmw7NnDcCR1FU5QB2AZ6g+L5fV5b5MvDNNsdxLHC97Y3ARknfbnhtcfn1PmC17Z8BSFoH7Ac82aaYVgBHSdod2ASspEgqxwEXAjOBW4HzbX+3TTEMOw74R9u/BpC0uOG1bn1/hj1i+55yfwWwP7AK+IqkbwHfavP9iw7Y6s2cDWXybcXOFEn0DcC+wB2SXmX76fHeMOVrJoCAL9k+stwOtv3xMcp1s6t8U/l1qGF/+LhtCd/2FuAR4BzgexQ1leOBVwAPAFspfnne0q4YKurK92eM+0NRs90ZOJmi3+E1FH+o2huHYbDiVsHjFEl42L7luUbrgcW2t9h+BPgxRXIZ13RIJrcCpw13IEnaU9LLKP7tp5VlzgLubHMcdwHvkDRb0lzg7W2+X1VLgYspmjVLKarsP3KxbIGB9wKHSLqkzXHcAbyz7I/YDXhHm+/Xip2A/WzfBlwCzAPmtvOGxaC1ys2cZpYBB0paIGkWRfNy8agy36KolSBpPkWzZ91EF53yzRzbayR9BLhZ0k7AFoo29/PAMeVrT1D0q7QzjmVl1X0V8HOKKvsz7bxnRUuBDwN3235e0sbyHAC2ByWdCSyW9Kztv25HELZXSroOuJfi57GsHfepyQzgy5LmUdR8PzNR9b8eYhDVciXbWyVdACyh+Ld80fZqSZdT9CMuLl97s6Q1FLWxD9qesDmprJvTOZLm2n5O0q4Uf4kX2l7Z7bii9x12+Cz/ww3VPr93yEt/tqKGPpNJm/I1kx6zSNKhwGyKfpwkkqjEwOYe75VIMukg22d1O4boX0Oup5nTLkkmEX2gGAGbZBIRLTJisMebOb0dXRuNMSqwq3opnl6KBRLPsCGr0tYt0zaZAD31H5TeiqeXYoHEs62ZU2XrljRzIvqCGHRv/+3vu2QySy/wbOa0fJ3Z7Mru2rP1QTa7zm75EgCzZ81j9zkvaTmerXNmtBzLzLl7sOuL92s5lrr+78/cbQ92+e3W45k5b3Md4TB7r93Y/eC9Wo7n2R8/scF2pU+rG9hC6z/bduq7ZDKbObxWb+x2GNvo0Fd2O4QRfv66ed0OYZvBF3Q7gpH2Pvlfuh3CCP/3hP/1aNWydmomEVGToTwajohWFR2wqZlERMvSzImIGhRTECSZRESLjNjsPM2JiBoMpZkTEa1KB2xE1MKIwUxBEBF1SAdsRLTMJo+GI6IOygjYiGidgc3u7V/X3o4uIoCiAzZzwE5A0nO227p4UcRUkUfDEdEy0/uD1joWnaSzJf1Q0j2SrpY0ozx/paTVkm6VVGmimIjpp9qUjd2ctrEjyUTS71Asv3ms7SMplht8DzCHYjnCVwLfBT42zvsXSlouafmWEWtIR0wPwzWTKlu3dOrObwSOolgt/p7y+OUU6yxfV5b5MvC7Y73Z9iLbA7YHZtJj03dFdEidNRNJJ0l6SNJaSZeO8fo5kn5RtiTukfRHza7ZqT4TUSyH+aERJ6WPjiqXhY8jxmCLLUP1/LqWXQxXAW8C1lP8kV9se82ootfZvqDqdTtVM7kVOE3SiwEk7SnpZeX9TyvLnAXc2aF4IvpKMZ+JKm0VHAOstb3O9mbgWuDUVmPsSDIpM95HgJslrQJuAfYGngeOkXQ/cAJweSfiieg/xUxrVbYK9gEeazheX54b7d2SVkn6hqT9ml20Y4+GbV/Hb/pHhmWMSUQFRQds5Sc18yUtbzheZHvRJG/5beCrtjdJ+mPgSxR/8MeVcSYRfWISg9Y22B6Y4PXHgcaaxr7luW1sP9lw+HngU81u2tujYCIC+M1w+prWGl4GHChpgaR
ZwBnA4sYCkvZuODwFeKDZRVMziegTdc1nYnurpAuAJcAM4Iu2V0u6nGLc12LgQkmnAFuBp4Bzml03ySSiD9iwZai+hoTtG4EbR527rGH/Q8CHRr9vIkkmEX2gaOb0dq9EkklEn+jm526qSDKJ6AOTfDTcFUkmEX0hzZyIqEnmgI2IlhWz0yeZRESLjNg6lLWGI6IGaebUTDvvzIz5L+52GNs8cmlvdYrtu+djzQt1yKxze+s/v2+Y3e0Qdlie5kREbfI0JyJaV/1DfF2TZBLRB4ZnWutlSSYRfSI1k4homYGtNX5quB2STCL6QNYajojapM8kIlrn9JlERA0yaC0iapNkEhEtM2IwT3Miog7pgI2IljkdsBFRFyeZRETrMmgtImoyrWsmkj4KnA38AngMWAG8HfgRcBwwB/hPFCuHvQq4zvZH2hlTRD+a1uNMJB0NvBs4ApgJrKRIJgCbbQ9Iugi4HjiKYj3Tn0i6ctQK7EhaCCwEmL3T3HaFHNG7+mBC6XY+uD4WuN72RtvPAt9ueG14xfX7gNW2f2Z7E7AO2G/0hWwvsj1ge2DWTru0MeSI3mSKZk6VrQpJJ0l6SNJaSZdOUO7dkixpoNk1u9Vnsqn8OtSwP3ycfpyI7dTXAStpBnAV8CZgPbBM0mLba0aV2w24CPhBleu2s2ZyF/AOSbMlzaXoK4mIHWRX2yo4Blhre53tzcC1wKljlPsfwCeBjVUu2rZkYnsZRXNmFfBPFE2aZ9p1v4iprsZmzj4UD0SGrS/PbSPpNcB+tm+oGl+7mxRX2P64pF2BO4AVtj83/KLt24HbG47f0OZ4IvpSUeuo3MyZL2l5w/Ei24uqvlnSTsCngXOqR9j+ZLJI0qHAbOBLtle2+X4RU9Yk+kw22J6ow/RxRj7o2Lc8N2w34DDgdkkAvw0slnSK7cYkNUJbk4nts9p5/YjpZGiotkfDy4ADJS2gSCJnANt+V20/A8wfPpZ0O3DxRIkE2tsBGxE1MdX6S6o0hWxvBS4AlgAPAF+zvVrS5ZJO2dEY8xg2ok9Ue1BT8Vr2jcCNo85dNk7ZN1S5ZpJJRD+YXAdsVySZRPSLOqsmbZBkEtEnplTNpHz+PNf2r9oUT0SMo+Lo1q5p+jRH0jWSdpc0B7gfWCPpg+0PLSKG2eChnSpt3VLlzoeWNZF3UgyLXwD8x7ZGFRHbqfGzOW1RJZnMlDSTIpkstr2Fnu8KipiCXHHrkirJ5GrgpxSzot0h6WVA+kwiOqq+QWvt0rQD1vZngM80nHpU0vHtC2liW/Z4Af/6rgO6dfvt/NkR13Q7hBH+9qjDux3CNoObN3c7hBE2nP2abocw0oOTLN/j7YEqHbAXlR2wkvQFSSuBEzoQW0QMc70zrbVDlWbOe8sO2DcDe1B0vv5lW6OKiO31eJ9JlXEmw6nubcDflx8I6u3RMxFTUY8PWqtSM1kh6WaKZLKknBdyqL1hRcR2pkDN5FzgSGCd7V9L+i3gP7c3rIgYwfR8zaTK05whSY8AB0ma3YGYImIMvT6cvmkykfRHFNPd7wvcA7wOuJs80YnorB5PJlX6TC4CjgYetX088Grg6bZGFRHbs6ptXVKlz2Sj7Y2SkPQC2w9KOrjtkUXEbxjU4489qiST9ZJeCHwLuEXSL4FH2xtWRIzU3VpHFVU6YN9V7n5c0m3APOCmtkYVEdvr8T6TcZOJpD3HOH1f+XUu8FRbIoqIsfVrMgFWUITfWLcaPjbw8jbGFRGj9Wsysb2gEwFIes723E7cK6Jv9cGgtSqfGn6XpHkNxy+U9M72hhURo8nVtm6pMs7kY+VygQDYfhr4WNUbSPqgpAvL/Ssl/XO5f4KkrzSUmy/pbkknVw8/Yhrp8c/mVEkmY5WZzKz2S4Hjyv0BYG45DeRxwB0AkvYCbgAus33DJK4dMW1MhZrJckmflnRAuX2aonO2qhXAUZJ2BzZRDMUfoEgmS4GZwK3Af7N9y1gXkLRQ0nJJy7f+2/OTuHXEFNLjI2CrJJM/ATYD1wHXAhuB86veoJyA+hHgHOB7FAnkeOAVFIsmb6VIOG+Z4BqLbA/YHth5lzlVbx0xdVRt4lSsmUg6SdJDktZKunSM18+TdJ+keyTdKenQZtdsmkxsP2/70vKX+Wjb/932ZKsHS4GLKZo1S4HzgB/ZHv7nvxc4RNIlk7xuxPRRUzKRNAO4CngrcChw5hjJ4hrbr7J9JPAp4NPNrtupFXuWAnsDd9v+OUXtZunwi7YHgTOBEyS9v0MxRfSVGvtMjgHW2l5nezNFi+PUxgKjVu2cQ4U01ZG1hm3fStE3Mnx8UMP+3PLrJiZo6kRMe/V1ru4DPNZwvB547ehCks4HPgDMosKUI91bSzAiKlP5qeEqGzB/+IFFuS3ckXvavsr2AcAlwEeala8yOdJBwGeBvWwfJulw4BTbn9iRACNiB1V/UrPB9sAErz8O7NdwvG95bjzXUuSACVWpmXwO+BCwBcD2KuCMCu+LiDrV9zRnGXCgpAWSZlH8Pi9uLCDpwIbDk4GHm120Sp/JrrZ/OGp1i60V3hcRNaprQJrtrZIuAJYAM4AvlkvYXA4st70YuEDSiRSViF8Cf9jsulWSyQZJB1DmPEmnAT/bwX9HROyoGke32r4RuHHUucsa9i+a7DWrJJPzgUUU40AepxiAdvZkbxQRLejyUPkqqsy0tg44UdIcYCfbz7Y/rIjYTr8nE0mXjToGwPblbYopIsYwFSaUbhw6Pxt4O8VnaiIitqnSzPmfjceSrqDoBY6ITur3Zs4YdqUY5BIRnTIVOmAl3cdvcuIM4EVA+ksiOq3fkwlFH8mwrcDPbXd10Jp36p2Jdf/mgtO7HcIIz542s3mhDtn8wt75OQHMfK7Hfxub6fHwJ0wm5bwHS2wf0qF4ImIMovebORN+NqecZ+QhSS/tUDwRMZbJfWq4K6o0c/YAVkv6IQ2PiW2f0raoImJ7PV4zqZJMPtr2KCKiuSmQTN5me8TcrJI+CXy3PSFFxFj6us+k9KYxzr217kAiookeX4Rr3JqJpP8CvB94uaRVDS/tBtzV7sAiokGXE0UVEzVzrgH+CfgLoHFdjWdtP9XWqCJiO337Qb9yfeFnKJagiIgu6/U+k44sdRERNUgyiYiW9XmfSUT0CJVbL0syiegXqZlERB3SAVuS9HHgOdtXdOqeEVNKvz4ajoge0gczrbV14XJJH5b0Y0l3AgeX526XdGW5oPIDko6W9E1JD0vK+sUR4+nX4fStknQUxRqmR5b3WQmsKF/ebHtA0kXA9cBRwFPATyRdafvJUddaCCwEmDl3j3aFHNHTpnPN5DjgH23/2vavGLkw8vD+fcBq2z+zvQlYx8jV2QGwvcj2gO2BnXeZ08aQI3pYj9dM2trMmcCm8utQw/7wcfpxIsYgV9sqXUs6SdJDktZKunSM1z8gaY2kVZJulfSyZtdsZzK5A3inpF0k7Qa8o433ipjaqtZKKiSTcm7nqyimEjkUOFPSoaOK/QgYsH048A3gU82u27ZkYnslcB1wL8Wnj5e1614RU52odQ7YY4C1tt
fZ3gxcC5zaWMD2bbZ/XR5+nwprZbW1SWH7z4E/H3X6iobXbwdubzh+Qzvjiehr1ftD5kta3nC8yPaihuN9gMcajtcDr53geudSVAgmlP6JiD4hV84mG2wP1HJP6WxgAPi9ZmWTTCL6Qb1Pah5n5FPTfctzI0g6Efgw8Hvl09YJdetpTkRMUo1Pc5YBB0paIGkWxXiwxqEbSHo1cDVwiu0nqlw0ySSiX9T0NKdc3vcCYAnwAPA126slXS5peD2svwLmAl+XdI+kxeNcbps0cyL6RJ0jYG3fCNw46txlDfsnTvaaSSYR/cB9PKF0RPSYHv9sTpJJRB8Qvf9BvySTiH5RfZxJVySZRPSJ1EwionVZ6iIi6pKnOTWb+exW9r7tF90OY5uN+83rdggjvOgfVnc7hG3+7fUHdTuEEW772893O4QRZixqXqZRkklEtM6kAzYi6pEO2IioR5JJRLQqg9Yioh52+kwioh55mhMRtUgzJyJaZ2Cot7NJkklEv+jtXJJkEtEv0syJiHrkaU5E1CE1k4homQxKB2xE1KLHx5l0bN0cSftLur9T94uYamRX2rolNZOIftAHM611ekW/GZI+J2m1pJsl7SLpQklrJK2SdG2H44noE/7N53OabV3S6ZrJgcCZtt8n6WvAu4FLgQW2N0l64VhvkrQQWAgwe+buHQs2opf0+tOcTtdMHrF9T7m/AtgfWAV8RdLZwNax3mR7ke0B2wOzZuzamUgjek2P10w6nUw2NewPUtSMTgauAl4DLJOUfpyI0QwadKWtCkknSXpI0lpJl47x+r+XtFLSVkmnVblmp5PJWPffz/ZtwCXAPIqV1yNiNFfcmpA0g+IP+FuBQ4EzJR06qti/AOcA11QNr9u1gBnAlyXNo5hM6jO2n+5yTBE9qcbHvscAa22vAygffJwKrBkuYPun5WuVR7d0LJmUwR3WcHxFp+4dMSVUTybzJS1vOF5ku3FhjX2AxxqO1wOvbTG6rtdMIqIKM5kRsBtsD7QvmLElmUT0AVHr6NbHgf0ajvctz7Wk2x2wEVFVfY+GlwEHSlogaRZwBrC41fCSTCL6gYFBV9uaXcreClwALAEeAL5me7WkyyWdAiDpaEnrgdOBqyU1XXc2zZyIPlHnh/hs3wjcOOrcZQ37yyiaP5UlmUT0i8y0FhGtyyJcEVEHk2QSETXp8ZnWkkwi+kQ3Z1GrIskkoh8YGOztqkmSSURfSAds7X618V83LFnzF4/WcKn5wIaWr7KmeZGK6omnHvXEclPrgZRqiWfG3jVEUqjrZ/WySZVOMqmX7RfVcR1Jy7vxYajx9FI8vRQLJJ5tkkwiomUGsghXRLTO4HTA9qpFzYt0VC/F00uxQOLpi6c50/ZTw6Nmnuq6OuOR9Fz59SWSvtGk7J9KGjHlf7NYJL1B0ncmEc/tkna4j2Eq/6wmeePMTh+tKycBnhTb/892s5nF/xTI+iH9IMkkJlKuwfygpK9IekDSN4ZrCpJ+KumTklYCp0s6QNJNklZIWirpkLLcAkl3S7pP0idGXfv+cn+GpCsk3V+unvgnki4EXgLcJum2styby2utlPR1SXPL8yeVca4Efn+cf8t29xijzGclLS9XdfyzhvN/2bCy4xXludPLa90r6Y56vuP9Kiv6RTUHA+favkvSF4H3A8MTbj9p+zUAkm4FzrP9sKTXAn8NnAD8b+Cztv+PpPPHucdCikXPjrS9VdKetp+S9AHgeNsbJM0HPgKcaPt5SZcAH5D0KeBz5b3WAtdVvccYZT5c3ncGcKukwymmDHwXcIhtN6zseBnwFtuPj7fa47RhYCh9JtHcY7bvKve/DPxuw2vXAZQ1hH8HfF3SPcDVwPAwrGOBr5b7fz/OPU4Eri5n2cL2U2OUeR3FOip3lff4Q4qBVYdQrMb4sG2XMe7oPf6grN38CHhleb9ngI3AFyT9PvDrsuxdwN9Jeh/FsijTW2omUcHo/wGNx8+XX3cCnrZ9ZMVr7AgBt9g+c8RJabx7Tu7i0gLgYuBo27+U9HfA7LIWcwzwRuA0iikFT7B9XlkDOxlYIeko20/WEUtf6vFBa6mZ9IaXSnp9uX8WcOfoArZ/BTwi6XQAFY4oX76LYlJggPeMc49bgD8eXn61oQnyLLBbuf994FhJryjLzJF0EPAgsL+kA8pyI5JNhXsM250iOT4jaS+KFeWGa13zyqkE/ytwRHn+ANs/KKcT/AUjZ1SfXmw8OFhp65Ykk97wEHC+pAeAPYDPjlPuPcC5ku4FVlOswgZwUfn++ygWWBrL5ymWfFxVvv+s8vwi4CZJt9n+BcWSkF+VtAq4m6IfYyNFf8gNZRPliUneAwDb91I0bx6kWHZyuGm3G/Cd8p53Ah8oz/9V2al8P/A94N5x7js9DLna1iVyj1edpjpJ+wPfsX1Yk6Ixjc3b+UV+/W6nNi8ILHn6CyuyCFdEjM3u+ac5SSZdNnoN5ohx9XgrIskkok84NZOIaF1mWouIOhjo4mPfKpJMIvqAAWdypIhomTM5UkTUpNdrJhm0FtEHJN1EMSt+FRtsn9TOeMaSZBIRtchncyKiFkkmEVGLJJOIqEWSSUTUIskkImqRZBIRtUgyiYhaJJlERC2STCKiFv8fbW6dzibKLuUAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "plt.matshow(CMbroadbrush)\n", "plt.xticks(range(max(truth)+1), names)\n", @@ -436,14 +567,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "**3) Cruise classifier CM:** \n", + "**3) Cruise control classifier CM:** \n", "\n", "This is where the confusion matrix has high values on the column of one specific class which means that the classifier constantly classifies all entries as one specific class" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -468,9 +599,32 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAARMAAAD0CAYAAAC4n8I2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGjxJREFUeJzt3X2UHXWd5/H3JyEhkISghnEUUDIIYnQQJaAuww4gKoqCjjgr6O6wPmRdYcDl6AJHRYZ1dtVhZcZzGIf4cMYZH0AdRzISRU4G5EHUkAiBBJEYVMJ6hPAkD4Yk3Z/9o+rC7aa7b3Vu3b630p/XOXW6qm51/b7pdL751a9+9S3ZJiKiWzP6HUBE7BySTCKiFkkmEVGLJJOIqEWSSUTUIskkImqRZBIRtUgyiYhaJJlERC126XcAEdHZ646e6/sfGKp07Oq1T1xp+7geh/Q0SSYRDbD5gSF+fOU+lY6d9ZxfLOxxOGNKMoloBDPk4X4HMaEkk4gGMDDMYD+Um2QS0RDDpGcSEV0yZmjAy4VM21vDkh7tdwyDrp8/I0nnS/pgv9pvi2M/Sbf1Ow4oLnOqLP2SnklEAxgYGvAxk2nRM5H0Tkk/kXSzpEskzSz3XyRpnaSVkvaagjg+KukOSddL+pqkD0q6pozjJkm3SzpM0rck3Snp4z2O50OSzijXL5L07+X6MZK+0nbcQkk3Sjq+x/F8WNLPJV0PvLDc17efT5uZkj5X/q58X9Juks6QtF7SWkmX9joAA9s8XGnpl50+mUh6EfCfgCNsHwIMAe8A5gI32X4x8APgYz2O4zDgrcBLgdcDS9o+3mp7CfAPwOXAacBLgFMlPauHYV0HHFmuLwHmSZpV7ru2jPvZwBXAebav6FUgkg4F3g4cArwBOKzt4379fFoOAC4uf1ceovh7PAd4me2DgfdNQQwMV1z6ZTpc5rwaOBRYJQlgN+Beip/7ZeUxXwa+1eM4jgAut70F2CLp39o+W15+vRVYZ/s3AJI2AvsC9/coptXAoZL2AJ4A1lAklSOBM4BZwErgNNs/6FEMLUcC/2r7cQBJy9s+69fPp+Uu2zeX66uB/YC1wFckfRv4do/bLwZgc5nTdwK+ZPuQcnmh7fPHOK6ff1NPlF+H29Zb2z1L+La3AXcBpwI/pOipHA28ALgd2E7xj+d1vYqhor78fMZoH4qe7S7A8cDFwMsp/qPqbRyGoYpLv0yHZLISOEnSHwBIeqak51P82U8qjzkFuL7HcdwAvEnSHEnzgDf2uL2qrgM+SHFZcx1Fl/2nLl5bYOBdwEGSzu5xHNcCby7HI+YDb+pxe92YAexr+2rgbGABMK+XDRaT1nKZ01e210v6CPB9STOAbRTX3I8Bh5ef3UsxrtLLOFaVXfe1wG8puuwP97LNiq4DPgzcaPsxSVvKfQDYHpJ0MrBc0iO2/74XQdheI+ky4BaKv49VvWinJjOBL0taQNHz/Yzth3rbpBhCvW2iS8p7c6aOpHm2H5W0O8X/xEttr+l3XDH4XnLwbP/LFdWe3zvoeb9ZXQ5YT6mdvmcyYJZJWgzMoRjHSSKJSgxsHfBRiSSTKWT7lH7HEM017MG+zEkyiWiAYgZskklEdMmIoQG/zBns6HpI0tJ+x9BukOIZpFgg8bQMW5WWfpm2yQQYqF9QBiueQYoFEs+TlzlVln7JZU5EI4ghD/b//Y1LJrN32d1zdt2z6/PMmb2APeY+t+tJNsO7zuw6FoDZu+/JvGfu23U8u//h413HMu8Pd+cPFj+r61i23D2n61gA5uy6gD3m7911PNqytY5wmDNjHgtm7dV1PL/bvnmz7UpPqxvYRj2/a73SuGQyZ9c9eeXiwen1Prqop7OoJ+2Qs2/ufNAU+fkHXtTvEEbY5Y67+x3CCFfed8mvqh5r19szkXQc8HcUs3k/b/sTYxzz58D5FLnslk5TGxqXTCKmq+GaxkPKej4XA68BNlE8qLjc9vq2Yw4AzqUo3fFg69m2iSSZRDRAMQBbW8/kcGCD7Y0AZXGnE4H1bce8l6KGy4MAtu/tdNLBHtGJiFJxmVNlARaWlelay+hxgb2B9mu+TeW+dgcCB0q6QdKPysuiCaVnEtEARQmCyv/3b67hQb9dKCrMHQXsA1wr6Y8nejo6ySSiAYzY6tru5txDUaGuZZ9yX7tNwI9bBbQk/ZwiuYxbGiKXORENMewZlZYKVgEHSFokaTZF7d3lo475NkWvBEkLKS57Nk500vRMIhqgzgFY29slnQ5cSXFr+Iu210m6gKLI+vLys9dKWk9RqvJDtiestZtkEtEARgzV+NyN7RXAilH7zmtbN3BWuVSSZBLREJMYgO2LJJOIBrDJszkRUQfVNgO2V5JMIhrAwFYP9j/XwY4uIoBiADY1YCcg6VHbg/XYbcSAGvSyjemZRDSAoeqEtL6ZsugkvVPSTyTdLOmS8jFoJF0kaZ2klZIqFYqJmH6qlWzsZ9nGKUkmkl5E8frNI2wfQjGj7h3AXIoZdy8GfgB8bJzvX9p6AnLb9u4riUU0TatnUtN0+p6YqsucVwOHU
hRhAdiN4n2yw8Bl5TFfBr411jfbXgYsA2optRjRRHlvTkEUr8M8d8RO6aOjjkuiiBiDLbYND/YQ51T1iVYCJ7VKv0l6pqTnl+2fVB5zCnD9FMUT0ShFPRNVWvplSlKd7fWSPgJ8X9IMYBtwGvAYcHj52b0U4yoR8TR51cWTbF/GU+MjLZljElFBMQCbMZOIqEEmrUVE1zKdPiJqk3omEdE1G7YNJ5lERJeKy5wkk4ioQWbARkTXcms4ImqSy5yIqElqwEZE14rq9IOdTAa73xQRQHE3Z/vwzEpLFZKOk3SHpA2Szhnj81Ml3VcWM7tZ0ns6nTM9k4iGqOsyp6xyeDHwGooXlK+StNz2+lGHXmb79KrnbV4ymSGGdpvV7yieNH/DI/0OYYRX7bGh3yE86c4bt/c7hBF+/9qX9zuEkb5b/dCa7+YcDmywvRFA0qXAicDoZDIpucyJaIgayzbuDdzdtr2p3DfaWyWtlfRNSft2OmmSSUQTuHjQr8oCLGzVTC6XpTvQ4r8B+9k+GLgK+FKnb2jeZU7ENNSqtFbRZttLJvj8HqC9p7FPue+p9uz72zY/D3yqU6PpmUQ0xCR6Jp2sAg6QtEjSbODtwPL2AyQ9p23zBOD2TidNzySiAQxsr+mpYdvbJZ0OXAnMBL5oe52kCyhePbMcOEPSCcB24AHg1E7nTTKJaIC6iyPZXgGsGLXvvLb1c4FzR3/fRJJMIhoi0+kjonvOU8MRUYOUIIiI2iSZRETXjBhKDdiIqEMGYCOia84AbETUxUkmEdG9vNEvImoyrXsmkj4KvBO4j6J+wmrgjcBPgSOBucB/oZi2+8cUlZ0+0suYIppoWs8zkXQY8FbgpcAsYA1FMgHYanuJpDOBy4FDKR4m+oWki0Y9/kxZj2EpwK67LuhVyBGDa5oXlD4CuNz2FtuPUBRbaWk97nwrsM72b2w/AWxkZJ0FAGwvs73E9pLZs+b2MOSIwWSKy5wqS7/0a8zkifLrcNt6azvjOBFPM/gDsL3smdwAvEnSHEnzKMZKImIH2dWWfulZL8D2KknLgbXAbykuaR7uVXsRO7tpfTcHuND2+ZJ2B64FVtv+XOtD29cA17RtH9XjeCIaqeh1TO9kskzSYmAO8CXba3rcXsROa9DHTHqaTGyf0svzR0wnw8PTOJlERD1Mf2/7VpFkEtEQfbxRU0mSSUQTZAA2Imoz4F2TJJOIhtipeiaSZgDzbP+uR/FExDj6Obu1io7T6SV9VdIekuYCtwHrJX2o96FFRIsNHp5RaalC0nGS7pC0QdI5Exz3VkmWNNGL0IFqz+YsLnsibwa+CywC/nOliCOiNnU9myNpJnAx8HpgMXByObl09HHzgTOBH1eJr0oymSVpFkUyWW57GwM/FBSxE3LFpbPDgQ22N9reClwKnDjGcf8L+CSwpcpJqySTS4BfUlRFu1bS84GMmURMqWq1TMpB2oWSbmpblo462d4UlQ9bNpX7nmpNejmwr+0rqkbYcQDW9meAz7Tt+pWko6s20BMzB2dUe+uzdut3CCNcft8h/Q7hSTP3H6yqeLv8fnu/Q+hO9euBzbY7jnGMp7zR8mng1Ml8X5UB2DPLAVhJ+oKkNcAxOxZmROwQ11pp7R5GVjTcp9zXMh94CXCNpF8CrwSWdxqErXKZ865yAPa1wDMoBl8/USXiiKhRfWMmq4ADJC2SNBt4O0+VUsX2w7YX2t7P9n7Aj4ATbN800UmrJJNWqnsD8M+217Xti4ipYlVbOp3G3g6cDlwJ3A583fY6SRdIOmFHw6syaW21pO9T3BI+t7xdNLyjDUbEDqrxHqrtFcCKUfvOG+fYo6qcs0oyeTdwCLDR9uOSngX81yonj4iamEq9jn6qcjdnWNJdwIGS5kxBTBExhkGfTt8xmUh6D8UsuH2AmylGdm8kd3QiptaAJ5MqA7BnAocBv7J9NPAy4KGeRhURT1fTAGyvVBkz2WJ7iyQk7Wr7Z5Je2PPIIuIpBg34bY8qyWSTpD2BbwNXSXoQ+FVvw4qIkfrb66iiygDsW8rV8yVdDSwAvtfTqCLi6QZ8zGTcZCLpmWPsvrX8Og94oCcRRcTYmppMgNUU4bf3rVrbBv6oh3FFxGhNTSa2F01FAJIetT1vKtqKaKwGTFqr8tTwWyQtaNveU9KbextWRIwmV1v6pco8k4/Zfri1Yfsh4GNVG5D0IUlnlOsXSfr3cv0YSV9pO26hpBslHV89/IhppL6nhnuiSjIZ65jJVLW/DjiyXF8CzCvLQB4JXAsg6dnAFcB5k6nsFDGd7Aw9k5skfVrS/uXyaYrB2apWA4dK2gN4gmIq/hKKZHIdMAtYCfxP21eNdQJJS1sl6LZue2wSTUfsRAZ8BmyVZPKXwFbgMorCs1uA06o2UBagvouiBNwPKRLI0cALKGopbKdIOK+b4BzLbC+xvWT2rLlVm47YeVS9xOljz6TKpLXHgHHfq1HRdcAHgXdRzFX5NLDatiW53P8NSWfb/mSXbUXsnAb81nC1N/Z07zrgOcCNtn9L0bu5rvWh7SHgZOAYSe+fopgiGmXQx0ym5F3DtldSjI20tg9sW59Xfn2CCS51Iqa9Ae+Z5MXlEQ2gBjw1XGXS2oGSVkq6rdw+WNJHeh9aRIywE9zN+RxwLrANwPZaitL4ETGVmn43B9jd9k+kERmv4a9Gi2iefg6uVlElmWyWtD9lzpN0EvCbnkYVEU+3EyST04BlwEGS7qGYgPbOnkYVESP1+bZvFR3HTGxvtH0ssBdwkO0/sf3LnkcWESPVOGYi6ThJd0jaIOlpk1IlvU/SrZJulnS9pMWdzlnlVRfnjdoGwPYF1cKOiDrUdWtY0kzgYuA1wCZglaTltte3HfZV2/9QHn8Cxaz14yY6b5W7OY+1LUPA64H9JvsHiIiBcTiwobzq2ErxzN2J7QfY/l3b5lwq9HmqPJvzf9u3JV1I8cLjiJhK9Y2Z7A3c3ba9CXjF6IMknQacBcymwkv3duTZnN0p3u4XEVOl4nM55SDtwlbJjnJZukNN2hfb3h84G+g4UbXKmMmtPJUTZ1IMxGa8JGKqVe+ZbLa9ZILP7wH2bdvep9w3nkuBz3ZqtMqt4Te2rW8Hfmu7v5PWhgbnHtkuP5hMnajee88l6/odwpP+ljf1O4QRts6f1fmgQVbfr/0q4ABJiyiSyNuBU9oPkHSA7TvLzeOBO+lgwmRSjvpeafugHQo5Imoh6ptnYnu7pNMpxj5nAl+0vU7SBcBNtpcDp0s6luIxmgeBv+h03gmTie2h8l7082z/uvs/RkTskJqfGra9Algxat95betnTvacVS5zngGsk/QTitvDrcZOmGxjEdGFwbm6H1OVZPLRnkcREZ3tBMnkDbbPbt8h6ZPAD3oTUkSMpfHP5lBMuR3t9XUHEhEdNLWeiaT/Drwf+CNJa9s+mg/c0OvAIqJNnxNFFRNd5nwV+C7wfxj5qotHbD/Q06gi4mkGvQbsuMmkfL/wwxSvoIiIPhv0MZNUp49oiiSTiOhaw8dMImJAqFwGWZJJRFOkZxIRdcgAbEnS+cCj
ti+cqjYjdipNvTUcEQNkZ3jVRTckfVjSzyVdD7yw3HeNpIvKcnK3SzpM0rck3Snp472MJ6LRmjqdvluSDqWo4HRI2c4aoFWWbKvtJZLOBC4HDgUeAH4h6SLb948611JgKcCuuy7oVcgRA20690yOBP7V9uNl2fzlbZ+11m8F1tn+je0ngI2MrE0JgO1ltpfYXjJ71twehhwxwKZrz6SDJ8qvw23rre2M40SMYTr3TK4F3ixpN0nzYcCqC0c0SdVeyc7YM7G9RtJlwC3AvRQVsSNiB4gGPzVcB9t/Dfz1qN0Xtn1+DXBN2/ZRvYwnotEG/DIn4xMRDSEPdjZJMologjw1HBF1GfS7OUkmEU0x4Mmkp9PpI6I+crWl0rmk48q3dW6QdM4Yn58lab2ktZJWSnp+p3MmmUQ0Qfl60CpLJ+U7xC+meGXNYuBkSYtHHfZTYIntg4FvAp/qdN4kk4imqG/S2uHABtsbbW8FLgVOHNGUfbXtx8vNHwH7dDppkklEA4hJXeYsLJ/Kby1LR51ub+Dutu1N5b7xvJvitTcTygBsRFNUn2ey2faSOpqU9E5gCfCnnY5NMoloiBpvDd/DyKfz9yn3jWxPOhb4MPCn5VP9E8plTkQT1Pug3yrgAEmLJM2mqDvUXiIESS8DLgFOsH1vlZOmZxLREHU96Gd7u6TTgSuBmcAXba+TdAFwk+3lwN8A84BvSAL4te0TJjpv85LJY79n5o9u63cUT5qx1179DmGE1+6+rd8hPOlvZw3Wr9eM7QM+66uDOp8atr0CWDFq33lt68dO9pyD9bcdEWMzkxmA7Yskk4iGyLM5EVGPJJOI6FZr0togSzKJaAI7YyYRUY9pXQM2IuqTy5yI6J6B4cHOJkkmEU0x2LkkySSiKXKZExH1yN2ciKhDeiYR0TUZlAHYiKjFgM8zmbLiSJL2kzQ4tQMiGkZ2paVf0jOJaIIGvB50qss2zpT0OUnrJH1f0m6Szmh72c+lUxxPREP4qedzOi19MtU9kwOAk22/V9LXgbcC5wCLbD8hac+xvqks1b8UYA67T1mwEYNk0O/mTHXP5C7bN5frq4H9gLXAV8qS+tvH+ibby2wvsb1klnadmkgjBs2A90ymOpm0l8sfougZHU/xqsKXA6skZRwnYjSDhlxp6Zd+v+piBrCv7auBs4EFFBWxI2K0+l510RP97gXMBL4saQFFManP2H6ozzFFDKR+3vatYsqSie1fAi9p275wqtqO2CkkmURE10xmwEZE90S12a9VL4UkHSfpDkkbJJ0zxuf/UdIaSdslnVTlnEkmEU1R061hSTMp7qC+HlgMnCxp8ajDfg2cCny1ani5zIloAgP13fY9HNhgeyNAOfP8RGD9k80VY5xI1ctYJ5lENMQk7uYslHRT2/Yy28vatvcG7m7b3gS8osvwkkwiGqN6Mtlse0kvQxlLkklEI9Q6Vf4eYN+27X3KfV3JAGxEE5g6n81ZBRwgaZGk2cDbgeXdhphkEtEUwxWXDmxvB04HrgRuB75ue52kCySdACDpMEmbgLcBl0ha1+m8ucyJaIg6p9PbXgGsGLXvvLb1VRSXP5UlmUQ0gYGhwZ4Cm2QS0Qj9rVVSReOSySN+cPNV2y79VQ2nWghs7vos93YfSKmWeGY+p4ZI6vrZ8L+7P0Whnng6XvVXVtPPh+dP6ugkk3rZ3quO80i6qR/34sczSPEMUiyQeJ6UZBIRXTOQl3BFRPcMzgDsoFrW+ZApNUjxDFIskHgacTdn2k5aG/XgU9/VGY+kR8uvz5X0zQ7HfkDSiPeHdIpF0lGSvjOJeK6RtMNjDDvz39UkG051+uheWYNiUmz/P9udCtt8APIyokZIMomJlO9g/pmkr0i6XdI3Wz0FSb+U9ElJa4C3Sdpf0vckrZZ0naSDyuMWSbpR0q2SPj7q3LeV6zMlXSjptvLtiX8p6QzgucDVkq4uj3ttea41kr4haV65/7gyzjXAn43zZ3laG2Mc81lJN5Vvdfyrtv2faHuz44XlvreV57pF0rX1/MSbKm/0i2peCLzb9g2Svgi8H2gV3L7f9ssBJK0E3mf7TkmvAP4eOAb4O+Cztv9J0mnjtLGU4qVnh9jeLumZth+QdBZwtO3NkhYCHwGOtf2YpLOBsyR9Cvhc2dYG4LKqbYxxzIfLdmcCKyUdTPHE6luAg2y77c2O5wGvs33PeG97nDYMDGfMJDq72/YN5fqXgT9p++wygLKH8B+Ab0i6GbgEaE1ROwL4Wrn+z+O0cSxwSfmQF7YfGOOYV1KU8buhbOMvKCZWHUTxNsY7bbuMcUfb+POyd/NT4MVlew8DW4AvSPoz4PHy2BuAf5T0XorXokxv6ZlEBaN/A9q3Hyu/zgAesn1IxXPsCAFX2T55xE5pvDYnd3JpEfBB4DDbD0r6R2BO2Ys5HHg1cBLFE63H2H5f2QM7Hlgt6VDb99cRSyMN+KS19EwGw/MkvapcPwW4fvQBtn8H3CXpbQAqvLT8+AaKmhQA7xinjauA/9Z6/WrbJcgjwPxy/UfAEZJeUB4zV9KBwM+A/STtXx43ItlUaKNlD4rk+LCkZ1MUNG71uhaUT7L+D+Cl5f79bf+4fJr1PkYW9JlebDw0VGnplySTwXAHcJqk24FnAJ8d57h3AO+WdAvFkyYnlvvPLL//Vor6nmP5PEXF8bXl959S7l8GfE/S1bbvo6hI/jVJa4EbKcYxtlCMh1xRXqKM90TSeG0AYPsWisubn1FUPW9d2s0HvlO2eT1wVrn/b8pB5duAHwK3jNPu9DDsakufyAPeddrZSdoP+I7tl3Q4NKaxBbvs5VfNP7HzgcCVD31hdWrARsTY7IG/m5Nk0mej38EcMa4Bv4pIMoloCKdnEhHdS6W1iKiDgT7e9q0iySSiAQw4xZEiomtOcaSIqMmg90wyaS2iASR9j6IqfhWbbR/Xy3jGkmQSEbXIszkRUYskk4ioRZJJRNQiySQiapFkEhG1SDKJiFokmURELZJMIqIWSSYRUYv/D76H5qQMUnT9AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "plt.matshow(CMcruise)\n", "plt.xticks(range(max(truth)+1), names)\n", @@ -491,7 +645,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -526,9 +680,17 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2 3\n" + ] + } + ], "source": [ "CM = np.zeros((M_classes, M_classes))\n", "class_M = 2 #np.random.randint(0., M_classes, size=1)[0]\n", @@ -562,9 +724,32 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAARMAAAD0CAYAAAC4n8I2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGVZJREFUeJzt3XuUXlV5x/HvL5PEEBKiGHShRkKRSxEhmuCllBYQBUUpFmwBUamX1IpCS7HoApG2drVaKqushZZYXbUFJEqtRFEuK4IkGJUkQiBBRQm0WGoM98vKZWae/nHOG96ZzMy7J+857zln5vdZ66y8lzNnPwzhYe999tmPIgIzs25NqToAM5sYnEzMrBBOJmZWCCcTMyuEk4mZFcLJxMwK4WRiZoVwMjGzQjiZmFkhplYdgJl1dtzRu8cjjw4knbtm3dYbI+L4kkPaiZOJWQNsfnSAH934sqRzp+39y7klhzMiJxOzRggGYrDqIMbkZGLWAAEMUu+Hcp1MzBpiEPdMzKxLQTBQ8+1CJu2tYUlPVx1D3VX5O5J0saTzqmq/LY75ku6pOg7IhjkpR1XcMzFrgAAGaj5nMil6JpLOkPRjSXdKukJSX/75pZLWS1ouaa8exPFJST+TtFLSVyWdJ+nWPI7Vku6VdLikb0i6T9KnS47nY5LOzl9fKul7+etjJF3Vdt5cSasknVByPBdI+rmklcCB+WeV/X7a9En6Yv535SZJu0k6W9IGSeskXVN2AAFsj8GkoyoTPplI+m3gj4EjImIBMAC8C9gdWB0RrwS+D3yq5DgOB04GDgPeAixq+3pbRCwC/gW4DjgLOAQ4U9ILSwxrBXBk/noRMEvStPyz2/K4XwxcD1wUEdeXFYikhcCpwALgrcDhbV9X9ftp2R+4PP+78jjZv8ePA6+OiEOBD/UgBgYTj6pMhmHOG4GFwB2SAHYDNpH93pfm51wJfKPkOI4ArouILcAWSd9q+25Z/ufdwPqIeBhA0v3APOCRkmJaAyyUtAewFVhLllSOBM4GpgHLgbMi4vslxdByJPBfEfEsgKRlbd9V9ftp2RgRd+av1wDzgXXAVZK+CXyz5PazCVgPcyon4CsRsSA/DoyIi0c4r8p/U1vzPwfbXrfel5bwI2I7sBE4E/gBWU/laOAVwL1AP9l/PMeVFUOiSn4/I7QPWc92KnACcDnwGrL/UZUbR8BA4lGVyZBMlgOnSHoRgKQ9Je1D9s9+Sn7O6cDKkuO4HXi7pBmSZgFvK7m9VCuA88iGNSvIuuw/iaxsQQDvAw6SdH7JcdwGnJTPR8wG3l5ye92YAsyLiFuA84E5wKwyG8wWrXmYU6mI2CDpQuAmSVOA7WRj7meA1+bfbSKbVykzjjvyrvs64NdkXfYnymwz0QrgAmBVRDwjaUv+GQARMSDpNGCZpKci4vNlBBERayUtBe4i+/dxRxntFKQPuFLSHLKe72UR8Xi5TYoBVG4TXZLr5vSOpFkR8bSkmWT/J14cEWurjsvq75BDp8d/Xp/2/N5BL394TT5h3VMTvmdSM0skHQzMIJvHcSKxJAFsq/mshJNJD0XE6VXHYM01GPUe5jiZmDVAtgLWycTMuhSIgZoPc+odXYkkLa46hnZ1iqdOsYDjaRkMJR1VmbTJBKjVX1DqFU+dYgHHs2OYk3KkkHR8/pzYLyR9fITvL82fZbszf16q461vD3PMGkEMRDH/788fdL0ceBPwENkK3mURsaF1TkT8Rdv5HwVe3em6jUsmL9xzSsyb133YL3tpHwsOm971IpuN9xWzd++MaXOYM/Ml3S/66e/vPpa+WcyZ/qKuY4nt3ccCMIOZ7KE9u45H06cVEQ4z+mYz53kv7jqeJ7dt2hwRSU+rB7Cdvm6bbHkt8IuIuB8gf+r5D4ANo5x/GgkPwjYumcybN5WbvlPJ5tsjevfbPlB1CENM2fRY1SHs0P/w/1UdwhBTXzKv6hCGuOGBSx9MPTeiuJ4J8FLgf9rePwS8bqQT80dP9gW+1+mijUsmZpPVYPqt4bmSVre9XxIRS3ax2VOBayOiY9EeJxOzBsgmYJN7Jps7LKf/FdnWDS0vyz8byalkz7J15GRi1giFDnPuAPaXtC9ZEjmV7Mn5oS1KBwEvAFalXNTJxKwBsi0IikkmEdEv6SPAjWRPQH85ItZL+huy3Qdbm1GdClwTiU8DO5mYNUAgtkVhd3OIiO8A3xn22UXD3l88nms6mZg1xGBxw5xSOJmYNcA4J2Ar4WRi1gCBGPAWBGZWhKImYMviZGLWABEUeWu4FE4mZo2g8ayArYSTiVkDBLAt6v2fa72jMzMgm4D1HrBjkPR0RJRavMhsovCtYTPrWlD/RWs9i07SGZJ+nG8Dd0W+21Nre7j1kpZLStooxmzySduyscod7HuSTCT9Nln5zSMiYgFZ8ed3AbuTPVj0SuD7jLKbk6TFklZLWv3II1VWUzWrRqtnknJUpVfDnDcCC8n2mgTYjaye7CCwND/nSuAbI/1wvrHLE
qCQrRbNmsh1czIiK4f5iSEfSp8cdp4ThdkIIsT2wXpPcfaqT7QcOEXSiwAk7ZnvLTkFOCU/53RgZY/iMWuUbD8TJR1V6Umqi4gNki4EbpI0BdhOthXcM8Br8+82kc2rmNlOCt1prRQ96zdFxFKemx9p8RoTswTZBKznTMysAF60ZmZd83J6MytM3fczqXd0ZgZk+5lsH5ySdKToVLg8P+ePJG3IV6hf3ema7pmYNUA2zOld4XJJ+wOfIFu1/lhrWcdY3DMxa4gCn83ZUbg8IrYBrcLl7T4IXB4RjwFExKZOF3UyMWuA1q3hlIO81nDbsXjY5UYqXP7SYeccABwg6XZJP5R0fKcYPcwxa4RxDXM61RpOMRXYHziKrBbxbZJeFRGPj/YD7pmYNUSBy+lTCpc/BCyLiO0RsRH4OVlyGZWTiVkDZLvTK+lIsKNwuaTpZDWFlw0755tkvRIkzSUb9tw/1kU9zDFrgED0DxZTazixcPmNwJslbSDbf+hjEfHIWNd1MjFriCKfCO5UuDwiAjg3P5I0Lpls/PkLec+x7606jB2mbH2m6hCGeMOND1Qdwg6rTjyg6hCGGNhrTtUhDPVA+ql+0M/MClP3DaWdTMyaIPygn5kVoLXTWp05mZg1hHsmZta1APoTnwiuipOJWQN4cyQzK4znTMyse+E5EzMrgBetmVlhnEzMrGuBGPDdHDMrgidgzaxr4QlYMytKOJmYWfe8aM3MCjKpeyaSPgmcAfyGbGv9NcDbgJ8ARwK7A+8hK/bzKmBpRFxYZkxmTTSp15lIOhw4GTgMmAasJUsmANsiYpGkc4DrgIXAo8AvJV06fK/JvO7HYoAZU/coK2Sz+so3lK6zMm9cHwFcFxFbIuIp4Ftt37V2wr4bWB8RD0fEVrLdr+cNuw4RsSQiFkXEoulTZ5YYslk9BdkwJ+VI0anWsKQzJf1G0p358YFO16xqzmRr/udg2+vWe8/jmO2kuAnYlFrDuaUR8ZHU65bZM7kdeLukGZJmkc2VmNkuikg7EqTUGh630pJJRNxBNpxZB3yXbEjzRFntmU10BQ5zUmoNA5wsaZ2kayXtNP0wXNmL/S+JiAOA44B9gDURcVRErAaIiFsjYkePpf07M3tO1utITiadCpen+BYwPyIOBW4GvtLpB8qen1gi6WBgBvCViFhbcntmE9Y45kw6FS7vWGt42B3VfwU+26nRUpNJRJxe5vXNJpPBwcJuDe+oNUyWRE4Fhvy3KmnviHg4f3sicG+ni/rOiVkDBOm3fTteK63W8NmSTgT6ydaAndnpuk4mZg2RdqMm8Vqdaw1/gmxlejInE7MmiEn+bI6ZFajIrkkJnEzMGmJC9UwkTQFmRcSTJcVjZqNIXN1amY6L1iRdLWkPSbsD9wAbJH2s/NDMrCUCYnBK0lGVlJYPznsiJ5Eti98XeHepUZnZTgp8NqcUKclkmqRpZMlkWURsp/ZTQWYTUCQeFUlJJlcAD5DtinabpH0Az5mY9VTaczlVTtJ2nICNiMuAy9o+elDS0eWFZGYjqvl4IGUC9px8AlaSviRpLXBMD2Izs5bxPTVciZRhzvvyCdg3Ay8gm3z9h1KjMrOd1XzOJGWdSSvVvRX4j/yBoHqvnjGbiGq+aC2lZ7JG0k1kyeRGSbPJ9mo1s16aAD2T9wMLgPsj4llJLwT+pNywzGyIoPY9k5S7OYOSNgIHSJrRg5jMbAR1X07fMZnk9TLOIdva7U7g9cAqfEfHrLdqnkxS5kzOAQ4HHoyIo4FXA4+XGpWZ7SyUdlQkZc5kS0RskYSk50XETyUdWHpkZvacANX8tkdKMnlI0vOBbwI3S3oMeLDcsMxsqGp7HSlSJmDfkb+8WNItwBzghlKjMrOdNXXORNKeww+yqnwrgVk9i9DMMgWuM+lUuLztvJMlhaSx6vAAY/dM1uShtfetWu8D+K20sM2sEAX1TFILl+cLVM8BfpRy3VGTSUTsu+vhppP0dES4p2M2lmIXre0oXA4gqVW4fMOw8/4W+AyQtLNiylPD75A0p+398yWdlBq1mRVDkXbQudZwx8Llkl4DzIuI61PjS1ln8qmIeKL1JiIeBz6V2oCkj0k6O399qaTv5a+PkXRV23lzJa2SdELqtc0mlfQ5k80RsajtWDKeZvKN4z8H/OV4fi4lmYx0znh2tV8BHJm/XgTMyreBPBK4DUDSi4HrgYvGkwnNJpNx9Ew66VS4fDZwCHCrpAfIVr0v6zQJm5JMVkv6nKT98uNzZJOzqdYACyXtAWwlW4q/iCyZrACmAcuBv4qIm0e6gKTFrS7btv5nx9G02QRS3ArYHYXLJU0nK1y+bEczEU9ExNyImB8R84EfAidGxOqxLpqSTD4KbAOWAtcAW4CzUiLOA9sObCQrfPwDsgRyNPAKssrq/WQJ57gxrrGk1WWbPnVmatNmE0fqECehZxIR/UCrcPm9wNdahcvzYuW7JGXR2jPAqPehE60AzgPeR7ZW5XPAmogISZF//nVJ50fEZ7psy2xiKnDRWqfC5cM+Pyrlmr2q2LMC2BtYFRG/JuvdrGh9GREDwGnAMZI+3KOYzBqlwDmTUvSk1nBELCebG2m9P6Dt9az8z62MMdQxm/RqvpzehcvNGkANeGo4ZdHaAZKWS7onf3+opAvLD83Mhqj5fiYpcyZfBD4BbAeIiHVkt5LMrJcmwIbSMyPix8OqW/SXFI+ZjaLKydUUKclks6T9yHOepFOAh0uNysx2NgGSyVnAEuAgSb8iW4B2RqlRmdlQFd/2TZGyaO1+4FhJuwNTIuKp8sMys500PZlIumjYewAi4m9KisnMRlD3W8Mpw5xn2l7PAN5Gtp7fzGyHlGHOP7W/l3QJ2QNCZtZLTR/mjGAm2f4HZtYrE2ECVtLdPJcT+4C9AM+XmPVa05MJ2RxJSz/w63w/hErEFDE4uz710/u212v93opD6/O7+fyDV1cdwhBnv/6UqkPoTpOTSb4l/o0RcVCP4jGzEYj6D3PGfDYn32fkZ5Je3qN4zGwk+VPDKUdVUoY5LwDWS/oxbbeJI2KXt3czs11Q855JSjL5ZOlRmFlnEyCZvDUizm//QNJngO+XE5KZjaTRcya5N43w2VuKDsTMOuhh4XJJH5J0t6Q7Ja2UdHCna46aTCT9Wb7G5EBJ69qOjcC6tJDNrBAFlrpoK1z+FuBg4LQRksXVEfGqiFgAfJasosSYxhrmXA18F/h7hpa6eCoiHu0cspkVqcA7NR0Ll0fEk23n705Cmho1meT1hZ8gK0FhZhUbx5zJXEnt1feWDKs3PFLh8tft1J50FnAuMB04plOj3p3erCnSk8nmiBizLnBScxGXA5dLOh24EHjvWOf3qgiXmXWjwDkTOhcuH+4a4KROF3UyMWsAjeNIMGbhcgBJ+7e9PQG4r9NFPcwxa4qC1plERL+kVuHyPuDLrcLlwOqIWAZ8RNKxZCVuHqPDEAecTMwao8hFa50Kl0fEOeO9Zs+SiaSLgacj4pJetWk2oUyAPWDNrGoN2Gmt1AlYSRdI+rmklcCB+We3SrpU0mpJ90o6
XNI3JN0n6dNlxmPWaBOgPOgukbSQbJZ4Qd7OWmBN/vW2iFgk6RzgOmAh8CjwS0mXRsQjw661GFgMMGP6nLJCNqu1ydwzORL4r4h4Nl+a237rqfX6bmB9RDwcEVuB+xl6/xuAiFgSEYsiYtG0qTNLDNmsxiZrz6SDrfmfg22vW+89j2M2gsncM7kNOEnSbpJmA28vsS2zia3YFbClKK0XEBFrJS0F7gI2ka26M7NdICZGedBdFhF/B/zdsI8vafv+VuDWtvdHlRmPWaPVfJjj+QmzhlDUO5s4mZg1QcXzISmcTMwaou53c5xMzJrCycTMiuCeiZl1Lyb5rWEzK5B7JmbWLeFhjpkVxetMzKwIde+ZeHd6syYo+EG/hFrD50rakJcEXi5pn07XdDIxawgNph0dr5NWa/gnwKKIOBS4lqze8JgaN8zR9gH6Hq5PqeOYOaPqEIbo22uvqkPY4SMHHlt1CEN89/4bqg5hiL69x3d+j2sN39J2/g+BMzpd1D0TsyYIsgnYlKOzkWoNv3SM898PfLfTRRvXMzGbrAosXJ7epnQGsAj4/U7nOpmYNUVxhcuTag3nFf0uAH4/36N5TB7mmDVAa9FaypEgpdbwq4ErgBMjYlPKRd0zMWuC9PmQhEsl1Rr+R2AW8HVJAP8dESeOdV0nE7OGKPJBv4Raw+O+FedkYtYQdV8B62Ri1gQBDNY7mziZmDVFvXOJk4lZU3iYY2bF8BYEZlYE90zMrGsKkCdgzawQNd9QumfL6SXNl3RPr9ozm2gUkXRUxT0TsyZoQHnQXj/o1yfpi5LWS7pJ0m6Szm7bHu6aHsdj1hCJe5lMop7J/sBpEfFBSV8DTgY+DuwbEVslPX+kH5K0GFgMMKNvds+CNauTut/N6XXPZGNE3Jm/XgPMB9YBV+WbsPSP9EMRsSQiFkXEoulTdutNpGZ1U/OeSa+TSfsGKwNkPaMTyDa3fQ1whyTP45gNF6CBSDqqUvXmSFOAefnmtecDc8j2UDCz4QosdVGGqnsBfcCVkuaQbSZ1WUQ8XnFMZrVU5W3fFD1LJhHxAHBI2/tLetW22YTgZGJmXQtqvwLWycSsAUS1q1tTOJmYNUXNk0nVd3PMLEUAA5F2JEgoXP57ktZK6pd0Sso1nUzMGqKoB/0SC5f/N3AmcHVqfB7mmDVFccOclMLlD+TfJU/7umdi1gjjetBvrqTVbcfiYRcbb+HyJO6ZmDVBMJ6eSadaw6VwMjFriuLWmSQVLh8vD3PMGqLAndY6Fi7fFU4mZk0QwMBg2tHpUhH9QKtw+b3A11qFyyWdCCDpcEkPAe8ErpC0vtN1Pcwxa4Ri9ypJKFx+B9nwJ1njksmT2zdtvuGhyx4s4FJzgc0FXKcodYqnTrFAQfH07V1AJJmifj/7jOvsmq+AbVwyiYi9iriOpNVVzHiPpk7x1CkWcDw7OJmYWdcCcBEuM+teQNR7D4LJnEyWVB3AMHWKp06xgON57m5OjU3aW8MRUau/oEXGI+np/M+XSLq2w7l/LmnmeGKRdJSkb48jnlsl7fIcw0T+dzXOhr07vXUvf9JzXCLifyOi0+Pjfw7M7HCO1YGTiY0lr8H8U0lXSbpX0rWtnoKkByR9RtJa4J2S9pN0g6Q1klZIOig/b19JqyTdLenTw659T/66T9Ilku7Jqyd+VNLZwEuAWyTdkp/35vxaayV9XdKs/PPj8zjXAn84yj/LTm2McM4X8ofP1kv667bP/6GtsuMl+WfvzK91l6TbivmNN5Ur+lmaA4H3R8Ttkr4MfBhobbj9SES8BkDScuBDEXGfpNcBnweOAf4Z+EJE/Luks0ZpYzFZ0bMFEdEvac+IeFTSucDREbFZ0lzgQuDYiHhG0vnAuZI+C3wxb+sXwNLUNkY454K83T5guaRDyZ4LeQdwUEREW2XHi4DjIuJXo1V7nDQCGPSciXX2PxFxe/76SuB3275bCpD3EH4H+LqkO4ErgNYyrCOAr+av/2OUNo4FrsiXUhMRj45wzuvJNsu5PW/jvWQLqw4iq8Z4X0REHuOutvFHee/mJ8Ar8/aeALYAX5L0h8Cz+bm3A/8m6YNkZVEmN/dMLMHwvwHt75/J/5wCPB4RCxKvsSsE3BwRpw35UBqtzfFdXNoXOA84PCIek/RvwIy8F/Na4I3AKWTPjRwTER/Ke2AnAGskLYyIR4qIpZFqvmjNPZN6eLmkN+SvTwdWDj8hIp4ENkp6J4Ayh+Vf30725CfAu0Zp42bgT1vlV9uGIE8BrWrwPwSOkPSK/JzdJR0A/BSYL2m//LwhySahjZY9yJLjE5JeTLZtYKvXNSd/XuQvgMPyz/eLiB/lz4z8hqGPzU8uEcTAQNJRFSeTevgZcJake4EXAF8Y5bx3Ae+XdBewnmyrPYBz8p+/m9F3zPpXsn091+U/f3r++RLgBkm3RMRvyPb9/KqkdcAqsnmMLWTzIdfnQ5RN42wDgIi4i2x481OyvUVbQ7vZwLfzNlcC5+af/2M+qXwP8APgrlHanRwGI+2oiKLmXaeJTtJ84NsRcUiHU20SmzN1r3jD7D/ofCJw4+NfWuOd1sxsZBG1v5vjZFKx4TWYzUZV81GEk4lZQ4R7JmbWvWrXkKRwMjFrggAqvO2bwsnErAECCG+OZGZdC2+OZGYFqXvPxIvWzBpA0g1ku+Kn2BwRx5cZz0icTMysEH42x8wK4WRiZoVwMjGzQjiZmFkhnEzMrBBOJmZWCCcTMyuEk4mZFcLJxMwK8f81O2FPmxFvcwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "plt.matshow(CMsubsumedto2)\n", "plt.xticks(range(max(truth)+1), names)\n", @@ -576,9 +761,32 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAARMAAAD0CAYAAAC4n8I2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGiZJREFUeJzt3X2UXXV97/H3J0NiyANRGhRRBIogRUQwidZyaQFRUQSxYC8gt1KpuVxRqBQucHmQ22vX9YHKravUEqpLKyBBaiWK8rAikRBR80AIJKggiIVSAgSQh5Wnme/9Y+8znJnMzPmdnL3PPnvm82LtNWef2fPbXyYz3/nt3/7t31cRgZlZpyZVHYCZjQ9OJmZWCCcTMyuEk4mZFcLJxMwK4WRiZoVwMjGzQjiZmFkhnEzMrBA7VB2AmbX23sOnx9Mb+pOOXblm0y0RcVTJIW3DycSsBp7a0M/Pbnl90rGTX/vr2SWHMyInE7NaCPpjoOogxuQxE7MaCGCASNpSSDpK0i8lPSjp/BE+/wZJt0u6W9IaSe9v1aaTiVlNDCT+14qkPuAK4H3A/sBJkvYfdthFwPURcTBwIvCPrdr1ZY5ZDQRBf3HLhbwdeDAiHgKQdB3wQWDdkFPCTvnrWcB/tGp0wvZMJL1QdQy9rsrvkaRLJZ1T1fmb4thT0n1VxwFtXebMlrSiaZs/rKnXAf/etP9o/l6zS4FTJD0K/AD4VKv43DMxq4EA+hPHQ4CnImJuh6c8Cfh6RPydpHcC35R0QMToo8ATomci6RRJP5e0WtKV+TUjki6XtFbSYkm7dCGOi/NBrzslfUvSOZKW5HGskHS/pHmSviPpAUmfLTmecyWdmb++XNKP8tdHSLqm6bjZku6SdHTJ8Vwo6VeS7gTelL9X2fenSZ+kq/KflVsl7SjpTEnr8sHJ68oOIIAtMZC0JXgM2L1p//X5e81OA64HiIi7gKnAmLecx30ykfQHwH8FDomIg4B+4CPAdGBFRLwZ+DHwmZLjmAccD7yVbOCr+S/H5vwvyT8BNwJnAAcAp0r6vRLDWgocmr+eC8yQNDl/74487tcANwGXRMRNZQUiaQ7ZQN9BwPuBeU2frur707APcEX+s/Is2b/j+cDBEXEgcHoXYmAgcUuwHNhH0l6SppB93xcNO+a3wLtg8HdoKvDkWI1OhMucdwFzgOWSAHYE1pN93xfmx1wNfKfkOA4BboyIjcBGSd9r+lzjH/JeYG1EPA4g6SGyvyBPlxTTSmCOpJ2ATcAqsqRyKHAmMBlYDJwRET8uKYaGQ4F/i4iXACQ1/3BX9f1peDgiVuevVwJ7AmuAayR9F/huyefPBmDTL3PGbitiq6RPArcAfcDXImKtpL8h+wO7CPhr4CpJnybrGJ0aLRaMngjJRMA3IuKCIW9KFw87rsqVtTflHweaXjf2S/s3iogtkh4GTgV+QvYLcjjwRuB+YCvZL897yXpvVank+zPC+SHr2e4IHA38MXAMcKGkt0TE1tIiCOgv8Cc0In5ANrDa/N4lTa/Xkf0BTDbuL3PI/rKeIOnVAJJ2lrQH2f/7CfkxJwN3lhzHMuAYSVMlzQA+UPL5Ui0FziG7rFlK1mW/O/8rFMDHgP0knVdyHHcAx+XjETPJfkl71SRg94i4HTiP7NbpjDJPmE1aK+wypxTjvmcSEeskXQTcKmkSsIXsmvtF4O3559aTjauUGcfyvOu+BniCrMv+XJnnTLQUuBC4KyJelLQxfw+AiOiXdBKwSNLzEdFy8tL2iIhVkhYC95D9eywv4zwF6QOuljSLrOf75Yh4ttxTin5U7ik6JNfN6R5JMyLiBUnTyP4Sz4+IVVXHZb3vgAOnxL/elPb83n5veHxlAbeG2zbueyY9ZkE+bXkq2TiOE4klCWBzj49KOJl0UUScXHUMVl8D0duXOU4mZjWQzYB1MjGzDgWiv8cvc3o7uhKN8PBTpXopnl6KBRxPw0AoaavKhE0mQE/9gNJb8fRSLOB4Bi9zUraq+DLHrBZEf/T23/7aJZNX7twXu72+r+N2dn1dH/sfOKXjSTaPPVjM2r1TJ+/ErGm7dT7pZ2vnM7qn9s1g1pRXdxxLbClmdvlUprGTdu44Hk2ZXEQ4TO2byaxXvKbjeH63ef1TEZH0tHoAW+j8575MtUsmu72+j6u/t2vVYQy64JiPVh3CEJOeeqbqEAZt/c8nqg5hiB12TVvdvVtu/u3/eyT12Aj3TMysIAO+NWxmncoGYN0zMbOO+TLHzAqQLUHgZGJmHQrE5vDdHDMrwECPX+b0dnRmBrw8AJuypUgoD3p5Xs1hdV4xoOXiT+6ZmNVAIPoLeu6mqTzou8kKcC2XtChf9zU7X8Snm47/FHBwq3bdMzGriQEmJW0JBsuDRsRmoFEedDQnAd9q1ah7JmY1EEE7t4ZnS1rRtL8gIhY07Y9UHvQdIzWUL76+F/CjVid1MjGrBbUzA7aI8qANJwI3RER/qwOdTMxqIIDNUdiva0p50IYTyao5tORkYlYDQaELHw2WByVLIieS1Y4aQtJ+wKuAu1IarTSZSHohIkotXmQ2XhT1bE5ieVDIksx1rcqCNrhnYlYDQbGT1lqVB833L22nza7dGpZ0iqSf55NgrszvdTcmx6yVtFhS0kIxZhNP2pKNVS7b2JVkIukPyMpvHhIRB5EVf/4IMJ2sW/VmssLYnxnl6+dLWiFpxTMbqqymalaNRs8kZatKty5z3gXMIZtpB1kV+fVkdZYX5sdcDXxnpC/O75EvAApZatGsjlw3JyOycpgXDHlTunjYcU4UZiOIEFsGenuIs1t9osXACZJeDSBp53xm3STghPyYk4E7uxSPWa1k65koaatKV1JdRKyTdBFwq6RJwBayiTAvAm/PP7eebFzFzLbhldYGRcRCXh4fafAcE7ME2QCsx0zMrABeUNrMOlbwdPpSOJmY1YQXlDazjkXAlgEnEzPrUHaZ42RiZgXwDFgz65hvDZtZQXyZY2YFqXKqfAonE7MayFandzI
xsw4FYutAb9ca7u2LMDMbVORTw63Kg+bH/JmkdflKiNe2arN2PZPHHpzNBR/8aNVhDPrhrddVHcIQR+01Yi2lSkyaPr3qEIbo/8/1VYew3Yq8m5NSHlTSPsAFZKsjPtNYPmQstUsmZhNVgXdzBsuDAkhqlAdd13TMx4ErIuIZgIhomYl9mWNWB5E96JeykZcHbdrmD2ttpPKgrxt2zL7AvpKWSfqppKNaheieiVkNNFZaS1REedAdgH2Aw8gq/t0h6S0R8exYX2BmNVDgDNiU8qCPAj+LiC3Aw5J+RZZclo/WqC9zzGoggK0Dk5K2BIPlQSVNIavct2jYMd8l65UgaTbZZc9DYzXqnolZDRS5OFJiedBbgPdIWkdW5+rciHh6rHadTMxqosjp9K3Kg+b1hc/OtyROJmZ1EH5q2MwK4CUIzKwwTiZm1rFA9HsNWDMrgtczMbOOhQdgzawo4WRiZp1zRT8zK8iE7plIuhg4BXiS7JHnlcAHgLuBQ4HpwJ+TLcLyFmBhRFxUZkxmdTSh55lImgccD7wVmAysIksmAJsjYq6ks4AbgTnABuDXki4f/gxAvh7DfICpk3cqK2Sz3lWDBaXLvHF9CHBjRGyMiOeB7zV9rvGE4r3A2oh4PCI2kT2VuPuwdoiIBRExNyLmTtmht5YCNOuGILvMSdmqUtWYyab840DT68a+x3HMttH7A7Bl9kyWAcdImippBtlYiZltp4i0rSql9QIiYrmkRcAa4AmyS5rnyjqf2Xg3oe/mAJdFxKWSpgF3ACsj4qrGJyNiCbCkaf+wkuMxq6Ws1zGxk8kCSfsDU4FvRMSqks9nNm71+phJqckkIk4us32ziWRgoLeTSW8/02xmQLYEQZG3hluVB5V0qqQnJa3Ot79s1aZvw5rVRFE3alLKg+YWRsQnU9t1z8SsDqLQSWuD5UEjYjPQKA/aEScTs7qIxK21lPKgAMdLWiPpBknbzEwfzsnErCba6Jm0qjWc4nvAnhFxIHAb8I1WX9DWmImkScCMiPjddgRnZh1oY3Zrq1rDLcuDDnvY9p+BL7Q6acueiaRrJe0kaTpwH7BO0rmtvs7MihMBMTApaUvQsjyopNc27R4L3N+q0ZQz75/3RI4DfgjsBfy3lIjNrDhFPZsTEVuBRnnQ+4HrG+VBJR2bH3ampLWS7gHOBE5t1W7KZc5kSZPJksk/RMQWSRU+TmQ2QRX4W5dQHvQCskXLkqX0TK4EfkO2KtodkvYAPGZi1lXFTlorQ8ueSUR8Gfhy01uPSDq8vJBaxCOIvr6qTr+Now/p+PZ8oWLTI1WHMOj8dcurDmGILx7y7qpDGOrxNo/v8euBlAHYs/IBWEn6qqRVwBFdiM3MGoqdtFaKlMucj+UDsO8BXkU2+Pq5UqMys20VN2mtFCkDsI1U937gm/mob28/vmg2HvX4EgQpPZOVkm4lSya3SJpJtlarmXXTOOiZnAYcBDwUES9J+j3gL8oNy8yGCHq+Z5JyN2dA0sPAvpKmdiEmMxtBlYtFp2iZTPJFUc4im7+/GvhD4C58R8esu3o8maSMmZwFzAMeiYjDgYOBZ0uNysy2FUrbKpIyZrIxIjZKQtIrIuIXkt5UemRm9rIA9fhtj5Rk8qikVwLfBW6T9AzQO9MszSaEansdKVIGYD+Uv7xU0u3ALODmUqMys231+JjJqMlE0s4jvH1v/nEGsKGUiMxsZHVNJsBKsvCb+1aN/QB+v8S4zGy4uiaTiNirGwFIeiEiZnTjXGa1VYNJaylPDX9I0qym/VdKOq7csMxsOEXaVpWUeSafiYjnGjsR8SzwmdQTSDpX0pn568sl/Sh/fYSka5qOmy3pLklHp4dvNoH0+LM5KclkpGPaWdV+KXBo/nouMCNfBvJQ4A4ASa8BbgIuiYib2mjbbMIosmfSqjxo03HHSwpJY612D6QlkxWSviRp73z7EtngbKqVwBxJOwGbyKbizyVLJkuBycBi4H9GxG0jNSBpfqMGyJatL7VxarNxpKAZsE3lQd8H7A+cJGn/EY6bSTYD/mcp4aUkk08Bm4GFZGUENwJnpDQOEBFbgIfJVrf+CVkCORx4I9nK2FvJEs57x2hjQUTMjYi5k3eYlnpqs/Ej9RInrWeSWh70/wCfJ/udb6llMomIFyPi/PyXeV5E/K+IeDEp5JctBc4hu6xZCpwO3B0Rjf/9jwH7STqvzXbNJo4ulgeV9DZg93aGHbpVHnQp8Frgroh4gizTLW18MiL6gZOAIyR9oksxmdVKG2MmHZUHzSt3fgn463a+rq3yoNsrIhaTjY009vdtej0j/7iJMS51zCa87pUHnQkcACzJV2jdFVgk6diIWDFao11JJmbWGRX71PBgeVCyJHIicHLjk/lUkNmD55aWAOeMlUggbdLavpIWS7ov3z9Q0kXb9b9gZtuvoLs5ieVB25bSM7kKOJessh8RsUbStcBnt/ekZrYdulgedNj7h6W0mZJMpkXEz4dVt9ia0riZFafXK3ynJJOnJO1NnhclnUD7hQ3NrFPjIJmcASwgmwfyGNkEtFNKjcrMhqr4Ib4UKSutPQQcKWk6MCkini8/LDPbRt2TiaRLhu0DEBF/U1JMZjaC8bCgdPPU+anAB8huJ5mZDUq5zPm75n1Jl5Hdnzazbqr7Zc4IppFNvzWzbhkPA7CS7uXlnNgH7AJ4vMSs2+qeTMjGSBq2Ak/k03HNrJvqnEzyFZluiYj9uhSPmY1A9P5lzpgP+uXrjPxS0hu6FI+ZjSR/ajhlq0rKZc6rgLWSfk7TbeKI2O6nC81sO/R4zyQlmVxcehRm1to4SCbvj4gha7NK+jzw43JCMrOR1HrMJPfuEd57X9GBmFkLPV6Ea9SeiaT/AXwC+H1Ja5o+NRNYVnZgZtak4kSRYqzLnGuBHwL/F2iu+PV8RGwoNSoz20ZtH/TLF5V9jqwEhZlVbDyMmZhZLyhwzKRVrWFJp0u6V9JqSXeOVD50OCcTszoosDxoYq3hayPiLRFxEPAFsqJcY3IyMasBtbElaFlrOCJ+17Q7nYQ05SJcZnWRPmYyW1JzwawFEbGgaX+kWsPvGN6IpDOAs4EpwBGtTupkYlYTbQzAtioPmiQirgCukHQycBHw0bGO79pljqRLJZ3TrfOZjTsDiVtrrWoND3cdcFyrRj1mYlYH+UprKVuCwVrDkqaQ1Rpe1HyApH2ado8GHmjVaKnJRNKFkn4l6U7gTfl7SyRdLmmFpPslzZP0HUkPSHLJUbPRFHQ3J7HW8CclrZW0mmzcZMxLHChxzETSHLKMd1B+nlXAyvzTmyNirqSzgBuBOcAG4NeSLo+Ip4e1NR+YDzB1yqyyQjbraUVOWmtVazgizmq3zTJ7JocC/xYRL+W3mZq7UY3X9wJrI+LxiNgEPMTQazkAImJBRMyNiLmTd5hWYshmPayuD/qVbFP+caDpdWPfd5jMRjCRp9PfARwnaUdJM4FjSjyX2fhW4AzYspTWC4iIVZIWAvcA68lGkM1sO4gaPzVchIj4W+Bvh719WdPnlwBLmvYPKzMes1rr8cscj0
+Y1YSit7OJk4lZHdR8pTUz6yG9fjfHycSsLpxMzKwI7pmYWedigt8aNrMCuWdiZp0Svswxs6J4nomZFcE9EzPrnCetmVlRfDenYNrST9/6Z6oO42WvmFJ1BEP07bJL1SEM+tz+86oOYYgFD/xr1SEMsdc2y4CNrdeTiReUNquDIBuATdkSJJQHPVvSOklrJC2WtEerNp1MzGqiqNXpE8uD3g3MjYgDgRvISoSOycnErC6KW2ktpTzo7RHxUr77U7LaOmNyMjGrgcaktcSeyey8lExjmz+suZHKg75ujNOfBvywVYy1G4A1m5DaGA+hoPKgAJJOAeYCf9LqWCcTs5oo8G5OUnlQSUcCFwJ/kpeiGZMvc8xqosvlQQ8GrgSOjYj1KY26Z2JWBwEMFDMFNiK2SmqUB+0DvtYoDwqsiIhFwBeBGcC3JQH8NiKOHbVRnEzM6qO75UGPbLdNJxOzmvCDfmZWDC9BYGZFcM/EzDqmABU0AFsWJxOzuvBTwxlJe0q6r1vnMxtvFJG0VcU9E7M6qMFKa92eAdsn6SpJayXdKmlHSWc2rZtwXZfjMauJxLVMJlDPZB/gpIj4uKTrgeOB84G9ImKTpFeO9EX5U4/zAab2zexasGa9pNfv5nS7Z/JwRKzOX68E9gTWANfkTyduHemLImJBRMyNiLlTJu3YnUjNek2P90y6nUyanzzsJ+sZHU226tPbgOWSPI5jNlyA+iNpq0rVTw1PAnaPiNuB84BZZA8Xmdlwxa20VoqqewF9wNWSZpEtJvXliHi24pjMelKVt31TdC2ZRMRvgAOa9i/r1rnNxgUnEzPrWNDzM2CdTMxqQFQ7uzWFk4lZXTiZmFnHAqjwtm+Kqm8Nm1miIh/0SygP+seSVknaKumElDadTMzqoqAZsInlQX8LnApcmxqeL3PMaqHQqfKD5UEB8gdsPwisGzxbNpUDKb1aj3smZnUQFPlsTrvlQZO4Z2JWF+nzTGZLWtG0vyAiFhQf0FBOJmY10cY8k1a1hpPKg7bLycSsDgLoL2wK7GB5ULIkciJwcqeNeszErBaKW2ktIrYCjfKg9wPXN8qDSjoWQNI8SY8CHwaulLS2VbuKHp9VN5ykJ4FHCmhqNvBUAe0UpZfi6aVYYPzGs0dE7JJy4Kypu8Yf7f7nSY3e/OAXV7a4zClF7S5zUr/5rUhaUcU3fDS9FE8vxQKOZ1CP/+GvXTIxm5ACcBEuM+tcQPT2GgQTOZmUft+9Tb0UTy/FAo6n6Ls5pZiwd3O6MYmnHUXGI+mF/ONukm5ocexfSZrWTiySDpP0/TbiWSJpu8cYxvO/VZsn9ur01rn84ay2RMR/RESrJz7/CpjW4hjrBU4mNpa8BvMvJF0j6X5JNzR6CpJ+I+nzklYBH5a0t6SbJa2UtFTSfvlxe0m6S9K9kj47rO378td9ki6TdF9ePfFTks4EdgNul3R7ftx78rZWSfq2pBn5+0flca4C/nSU/5dtzjHCMV+RtCKv6vi/m97/XFNlx8vy9z6ct3WPpDuK+Y7XlSv6WZo3AadFxDJJXwM+ATQW3H46It4GIGkxcHpEPCDpHcA/AkcAfw98JSL+RdIZo5xjPlnRs4MiYquknSNig6SzgcMj4ilJs4GLgCMj4kVJ5wFnS/oCcFV+rgeBhannGOGYC/Pz9gGLJR1INgvzQ8B+ERFNlR0vAd4bEY+NVu1xwghgwGMm1tq/R8Sy/PXVwH9p+txCgLyH8EfAtyWtBq4EXpsfcwjwrfz1N0c5x5HAlfnsRyJiwwjH/CHZ+hbL8nN8FNgD2I+sGuMDkc1yvLqDc/xZ3ru5G3hzfr7ngI3AVyX9KfBSfuwy4OuSPk5WFmVic8/EEgz/CWjefzH/OAl4NiIOSmxjewi4LSJOGvKmNNo522s8exbkHGBeRDwj6evA1LwX83bgXcAJZFO9j4iI0/Me2NHASklzIuLpImKppR6ftOaeSW94g6R35q9PBu4cfkBE/A54WNKHAZR5a/7pZWQPawF8ZJRz3Ab890b51aZLkOeBRjX4nwKHSHpjfsx0SfsCvwD2lLR3ftyQZJNwjoadyJLjc5JeQ7bSV6PXNSsifgB8Gnhr/v7eEfGziLgEeJKhT7pOLBFEf3/SVhUnk97wS+AMSfcDrwK+MspxHwFOk3QPsJZsdSyAs/Kvv5fRF7n5Z7Kl+NbkX994SnQBcLOk2yPiSbKl+r4laQ1wF9k4xkay8ZCb8kuU9W2eA4CIuIfs8uYXZMsBNi7tZgLfz895J3B2/v4X80Hl+4CfAPeMct6JYSDStorU7kG/8UbSnsD3I+KAFofaBDZrh13inTM/2PpA4JZnv+oH/cxsFBE9fzfHyaRiw2swm42qx68inEzMaiLcMzGzzlU7hySFk4lZHQRQ4W3fFE4mZjUQQHhxJDPrWHhxJDMrSK/3TDxpzawGJN1Mtip+iqci4qgy4xmJk4mZFcLP5phZIZxMzKwQTiZmVggnEzMrhJOJmRXCycTMCuFkYmaFcDIxs0I4mZhZIf4/FVuyrg+xoKQAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "plt.matshow(CMsubsumedfrom2)\n", "plt.xticks(range(max(truth)+1), names)\n", @@ -604,7 +812,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -629,7 +837,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -639,9 +847,23 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "proclam implementation of log-loss: 1.9467692362663755\n", + "proclam implementation of log-loss: 0.5873053661481069\n", + "proclam implementation of log-loss: 0.4985535342418646\n", + "proclam implementation of log-loss: 3.2716260697307087\n", + "proclam implementation of log-loss: 0.8391912085925096\n", + "proclam implementation of log-loss: 0.5795189847083857\n", + "proclam implementation of log-loss: 4.44547996411847\n" + ] + } + ], "source": [ "for candidate in [predictionB, predictionC, predictionG, predictionH, predictionI, predictionJ, predictionJ1]:\n", " D = proclam.metrics.LogLoss()\n", @@ -658,9 +880,21 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tunnel logloss = 1.3831099142162573\n", + "broadbrush logloss = 2.147514136377367\n", + "cruise logloss = 1.8925640255539373\n", + "subsumedto2 logloss = 0.8697035708801626\n", + "subsumedfrom logloss = 1.2313487087147357\n" + ] + } + ], "source": [ "test_cases = {'tunnel': CMtunnel, 'broadbrush': CMbroadbrush, 'cruise': CMcruise, 'subsumedto2': CMsubsumedto2, 'subsumedfrom': CMsubsumedfrom2}\n", "LL_metric = proclam.metrics.LogLoss()\n", @@ -673,9 +907,21 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tunnel logloss = 0.45996840473970635\n", + "broadbrush logloss = 3.3173356388856186\n", + "cruise logloss = 0.7884312373010852\n", + "subsumedto logloss = 0.5772789861655064\n", + "subsumedfrom logloss = 4.401321122622816\n" + ] + } + ], "source": [ "test_cases = {'tunnel': CMtunnel, 'broadbrush': CMbroadbrush, 'cruise': CMcruise, 'subsumedto': CMsubsumedto2,'subsumedfrom': CMsubsumedfrom2}\n", "LL_metric = proclam.metrics.LogLoss()\n", @@ -697,11 +943,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 30, "metadata": { "scrolled": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "proclam implementation of Brier score: 0.17481523141779132\n", + "proclam implementation of Brier score: 0.03349662221027182\n", + "proclam implementation of Brier score: 0.13994060524323776\n", + "proclam implementation of Brier score: 0.16442263732364013\n", + "proclam implementation of Brier score: 0.19760600980918508\n", + "proclam implementation of Brier score: 0.08131350128672857\n", + "proclam implementation of Brier score: 0.08047203028278058\n" + ] + } + ], "source": [ "for candidate in [predictionB, predictionC, predictionG, predictionH, predictionI, predictionJ, predictionJ1]:\n", " E = proclam.metrics.Brier()\n", @@ -720,9 +980,21 @@ }, { 
"cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tunnel Brier = 0.13835879922919994\n", + "broadbrush Brier = 0.16647095421593264\n", + "cruise Brier = 0.19722890971661755\n", + "subsumedto Brier = 0.07719957610858942\n", + "subsumedfrom Brier = 0.0838596231599401\n" + ] + } + ], "source": [ "test_cases = {'tunnel': CMtunnel, 'broadbrush': CMbroadbrush, 'cruise': CMcruise, 'subsumedto': CMsubsumedto2, 'subsumedfrom': CMsubsumedfrom2}\n", "B_metric = proclam.metrics.Brier()\n", @@ -735,9 +1007,21 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tunnel Brier = 0.03829109430194667\n", + "broadbrush Brier = 0.2519580848801754\n", + "cruise Brier = 0.07399778110647401\n", + "subsumedto Brier = 0.0487604318197624\n", + "subsumedfrom Brier = 0.28825253455457156\n" + ] + } + ], "source": [ "test_cases = {'tunnel': CMtunnel, 'broadbrush': CMbroadbrush, 'cruise': CMcruise, 'subsumedto': CMsubsumedto2, 'subsumedfrom': CMsubsumedfrom2}\n", "B_metric = proclam.metrics.Brier()\n", @@ -752,49 +1036,157 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Deterministic metrics\n", + "## Deterministic metrics\n", "\n", "Let's check that reducing the probabilities to class point estimates actually does what we want; the one based on a good confusion matrix should do better than the random guesser. " ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### True/False positive/negative rates\n", + "\n", + "Let's compare `proclam`'s calculation of the standard rates to that of `pycm`." + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rates(TPR=array([0.98113208, 1. , 0.94117647, 0.95548961, 0.96296296]), FPR=array([0.00527983, 0.00740741, 0.01256281, 0.01357466, 0.01540832]), FNR=array([0.01886792, 0. 
, 0.05882353, 0.04451039, 0.03703704]), TNR=array([0.99472017, 0.99259259, 0.98743719, 0.98642534, 0.98459168]), TP=array([ 52., 55., 192., 322., 338.]), FP=array([ 5., 7., 10., 9., 10.]), FN=array([ 1., 0., 12., 15., 13.]), TN=array([942., 938., 786., 654., 639.]))\n" + ] + } + ], + "source": [ + "detC = proclam.metrics.util.prob_to_det(predictionC)\n", + "cmC = proclam.metrics.util.det_to_cm(detC, truth)\n", + "print(proclam.metrics.util.cm_to_rate(cmC, vb=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rates(TPR={0: 0.9811320754716981, 1: 1.0, 2: 0.9411764705882353, 3: 0.9554896142433235, 4: 0.9629629629629629}, FPR={0: 0.0052798310454065245, 1: 0.007407407407407418, 2: 0.012562814070351758, 3: 0.013574660633484115, 4: 0.015408320493066285}, FNR={0: 0.018867924528301883, 1: 0.0, 2: 0.05882352941176472, 3: 0.04451038575667654, 4: 0.03703703703703709}, TNR={0: 0.9947201689545935, 1: 0.9925925925925926, 2: 0.9874371859296482, 3: 0.9864253393665159, 4: 0.9845916795069337}, TP={0: 52, 1: 55, 2: 192, 3: 322, 4: 338}, FP={0: 5, 1: 7, 2: 10, 3: 9, 4: 10}, FN={0: 1, 1: 0, 2: 12, 3: 15, 4: 13}, TN={0: 942, 1: 938, 2: 786, 3: 654, 4: 639})\n" + ] + } + ], + "source": [ + "compare = ConfusionMatrix(truth, detC)\n", + "print(proclam.metrics.util.RateMatrix(TPR=compare.TPR, FPR=compare.FPR, FNR=compare.FNR, TNR=compare.TNR, TP=compare.TP, FP=compare.FP, FN=compare.FN, TN=compare.TN))" + ] + }, + { + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "# from proclam.metrics import util as pmu\n", - "# from proclam.metrics import roc\n", - "# from importlib import reload\n", - "# reload(proclam.metrics.roc)\n", - "# from proclam.metrics import roc\n", + "### ROC AUC\n", "\n", + "The ROC is the true positive rate as a function of the false positive rate, where the rates are calculated from a confusion matrix derived by deterministically assigning classes based on a series of threshold values in probability." + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ROC AUC for prediction B = 0.493\n", + "ROC AUC for prediction C = 0.995\n", + "ROC AUC for prediction B with weird weights = 0.498\n" + ] + } + ], + "source": [ "ROC_metric = proclam.metrics.ROC()\n", - "rocB = ROC_metric.evaluate(predictionB,truth, 0.1)\n", + "rocB = ROC_metric.evaluate(predictionB,truth, 0.1, averaging='per_class', vb=False)\n", "rocC = ROC_metric.evaluate(predictionC,truth, 0.1)\n", - "\n", "print('ROC AUC for prediction B = %.3f'%rocB)\n", "print('ROC AUC for prediction C = %.3f'%rocC)\n", - "\n", "rocB = ROC_metric.evaluate(predictionB,truth,0.1, averaging=[0.1,0.3,0.2,0.2,0.2])\n", "print('ROC AUC for prediction B with weird weights = %.3f'%rocB)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's compare the ROC AUC calculated by `proclam` to that of `scikit-learn`. 
" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ - "from proclam.metrics import precision_recall\n", - "reload(proclam.metrics.precision_recall)\n", - "reload(proclam.metrics.util)\n", - "from proclam.metrics import precision_recall\n", - "from proclam.metrics import util\n", + "# this runs if you change line 78 in roc.py to return auc_class instead of auc_allclass\n", + "# truth_score = proclam.metrics.util.det_to_prob(truth)\n", + "# from_proclam, proclam_auc, from_skl = [],[],[]\n", + "# for m in range(M_classes):\n", + "# fpr, tpr, thresholds = skl.metrics.roc_curve(truth_score.T[m].T.astype(int), predictionB.T[m].T)\n", + "# proclam_says = ROC_metric.evaluate(predictionB, truth, thresholds, vb=True)[m]\n", + "# proclam_auc.append(ROC_metric.evaluate(predictionB, truth, thresholds, vb=False)[m])\n", + "# i = np.argsort(proclam_says[0])\n", + "# new_auc = proclam.metrics.util.auc(proclam_says[0][i], proclam_says[1][i])\n", + "# print('proclam: '+str(new_auc))\n", + "# from_proclam.append(proclam.metrics.util.auc(proclam_says[0], proclam_says[1]))\n", + "# skl_says = proclam.metrics.util.auc(fpr, tpr)\n", + "# print('skl: '+str(skl_says))\n", + "# from_skl.append(skl_says)\n", + "# print('proclam says %.3f'%np.mean(proclam_auc)+' scikit-learn says %.3f'%np.mean(from_skl))\n", + "# print('proclam says '+str(proclam_auc)+' scikit-learn says '+str(from_skl))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### PRC AUC\n", "\n", - "metric = precision_recall.Metric()\n", - "prB = metric.evaluate(predictionB,truth)\n", - "prC = metric.evaluate(predictionC,truth)\n", + "The precision is the number of correctly classified positives divided by the number all items classified as positive, whereas the recall is the number of correctly classified positives divided by the number of items whose true class was positive. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/aimalz/Code/proclam/proclam/metrics/util.py:455: RuntimeWarning: invalid value encountered in double_scalars\n", + " p = np.asarray(TP / (TP + FP))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Precision/Recall AUC for prediction B = 0.203\n", + "Precision/Recall AUC for prediction C = 0.937\n" + ] + } + ], + "source": [ + "metric = proclam.metrics.PRC()\n", + "prB = metric.evaluate(predictionB,truth, 0.01)\n", + "prC = metric.evaluate(predictionC,truth, 0.01)\n", "\n", "print('Precision/Recall AUC for prediction B = %.3f'%prB)\n", "print('Precision/Recall AUC for prediction C = %.3f'%prC)" @@ -802,15 +1194,46 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ - "from proclam.metrics import fscore\n", - "reload(proclam.metrics.fscore)\n", - "from proclam.metrics import fscore\n", - "\n", - "metric = fscore.Metric()\n", + "# compare = ConfusionMatrix(truth, proclam.metrics.util.prob_to_det(predictionC))\n", + "# print((compare.TPR, compare.PPV))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### F-score" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "prC = metric.evaluate(predictionC,truth, 0.01, vb=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "f1 score for prediction B = 0.184\n", + "f1 score for prediction C = 0.953\n" + ] + } + ], + "source": [ + "metric = proclam.metrics.F1()\n", "fB = metric.evaluate(predictionB,truth)\n", "fC = metric.evaluate(predictionC,truth)\n", "\n", @@ -818,6 +1241,104 @@ "print('f1 score for prediction C = %.3f'%fC)" ] }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "# compare = ConfusionMatrix(truth, proclam.metrics.util.prob_to_det(predictionB))\n", + "# print((compare.F1))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Matthews Correlation Coefficient" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MCC for prediction B = 0.009\n", + "MCC for prediction C = 0.942\n" + ] + } + ], + "source": [ + "metric = proclam.metrics.MCC()\n", + "mccB = metric.evaluate(predictionB, truth)\n", + "mccC = metric.evaluate(predictionC,truth)\n", + "\n", + "print('MCC for prediction B = %.3f'%mccB)\n", + "print('MCC for prediction C = %.3f'%mccC)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "# compare = ConfusionMatrix(truth, proclam.metrics.util.prob_to_det(predictionC))\n", + "# print((compare.MCC))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy for prediction B = 0.207\n", + "Accuracy for prediction C = 0.968\n" + ] + } + ], + "source": [ + "metric = proclam.metrics.Accuracy()\n", + "accB = metric.evaluate(predictionB, truth)\n", + "accC = metric.evaluate(predictionC,truth)\n", + "\n", + "print('Accuracy for prediction B 
= %.3f'%accB)\n", + "print('Accuracy for prediction C = %.3f'%accC)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{0: 0.785, 1: 0.769, 2: 0.675, 3: 0.601, 4: 0.596}\n" + ] + } + ], + "source": [ + "# compare = ConfusionMatrix(truth, proclam.metrics.util.prob_to_det(predictionB))\n", + "# print((compare.ACC))" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/proclam/metrics/__init__.py b/proclam/metrics/__init__.py index b70a93e..97292d2 100644 --- a/proclam/metrics/__init__.py +++ b/proclam/metrics/__init__.py @@ -5,3 +5,7 @@ from .brier import * from .logloss import * from .roc import * +from .prc import * +from .f1 import * +from .mcc import * +from .accuracy import * \ No newline at end of file diff --git a/proclam/metrics/accuracy.py b/proclam/metrics/accuracy.py new file mode 100644 index 0000000..00f94c9 --- /dev/null +++ b/proclam/metrics/accuracy.py @@ -0,0 +1,57 @@ +""" +A class for accuracy +""" + +from __future__ import absolute_import +__all__ = ['Accuracy'] + +import numpy as np + +from .util import weight_sum, check_weights +from .util import prob_to_det, det_to_cm, cm_to_rate +from .metric import Metric + +class Accuracy(Metric): + + def __init__(self, scheme=None): + """ + An object that evaluates the accuracy + + Parameters + ---------- + scheme: string + the name of the metric + """ + super(Accuracy, self).__init__(scheme) + self.scheme = scheme + + def evaluate(self, prediction, truth, averaging='per_class'): + """ + Evaluates the accuracy + + Parameters + ---------- + prediction: numpy.ndarray, float + predicted class probabilities + truth: numpy.ndarray, int + true classes + averaging: string or numpy.ndarray, float + 'per_class' weights classes equally, other keywords possible, vector assumed to be class weights + + Returns + ------- + accuracy_all: float + value of the metric + """ + prediction, truth = np.asarray(prediction), np.asarray(truth) + (N, M) = np.shape(prediction) + + dets = prob_to_det(prediction) + cm = det_to_cm(dets, truth) + rates = cm_to_rate(cm) + accuracy = rates.TPR + + weights = check_weights(averaging, M, truth=truth) + accuracy_all = weight_sum(accuracy, weights) + + return accuracy_all \ No newline at end of file diff --git a/proclam/metrics/f1.py b/proclam/metrics/f1.py new file mode 100644 index 0000000..90c204a --- /dev/null +++ b/proclam/metrics/f1.py @@ -0,0 +1,63 @@ +""" +A class for the F1 score +""" + +from __future__ import absolute_import +__all__ = ['F1'] + +import numpy as np + +from .util import weight_sum, check_weights +from .util import prob_to_det, det_to_cm, cm_to_rate +from .util import precision +from .metric import Metric + +class F1(Metric): + + def __init__(self, scheme=None): + """ + An object that evaluates the F1 score + + Parameters + ---------- + scheme: string + the name of the metric + """ + super(F1, self).__init__(scheme) + self.scheme = scheme + + def evaluate(self, prediction, truth, averaging='per_class'): + """ + Evaluates the F1 score + + Parameters + ---------- + prediction: numpy.ndarray, float + predicted class probabilities + truth: numpy.ndarray, int + true classes + averaging: string or numpy.ndarray, float + 'per_class' weights classes equally, other keywords possible, vector assumed to be class weights + + Returns + ------- + f1_all: float + value of the metric + """ + prediction, truth = np.asarray(prediction), np.asarray(truth) + (N, M) = np.shape(prediction) + + for 
m in range(M): + if not len(np.where(truth == m)[0]): + raise RuntimeError('No true values for class %i so F1 is undefined'%m) + dets = prob_to_det(prediction) + cm = det_to_cm(dets, truth) + rates = cm_to_rate(cm) + r = rates.TPR + p = precision(rates.TP, rates.FP) + f1 = 2 * p * r / (p + r) + + weights = check_weights(averaging, M, truth=truth) + f1_all = weight_sum(f1, weights) + + return f1_all \ No newline at end of file diff --git a/proclam/metrics/fscore.py b/proclam/metrics/fscore.py deleted file mode 100644 index d635e0b..0000000 --- a/proclam/metrics/fscore.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -A superclass for metrics -""" - -from __future__ import absolute_import -__all__ = ['Metric'] - -import numpy as np - -from .util import weight_sum -from .util import check_weights -from .util import prob_to_det_threshold -from .util import auc, precision, recall -from scipy.integrate import trapz -from sklearn.metrics import precision_recall_curve -from sklearn.metrics import f1_score - -class Metric(object): - - def __init__(self, scheme=None, **kwargs): - """ - An object that evaluates the F-score - - Parameters - ---------- - scheme: string - the name of the metric - """ - - self.debug = False - self.scheme = scheme - - def evaluate(self, prediction, truth, **kwds): - """ - Evaluates the area under the ROC curve for a given class_idx - - Parameters - ---------- - prediction: numpy.ndarray, float - predicted class probabilities - truth: numpy.ndarray, int - true classes - - Returns - ------- - metric: float - value of the metric - """ - - best_class = np.zeros(len(truth)) - for i in range(len(truth)): - best_class[i] = np.where(prediction[i,:] == np.max(prediction[i,:]))[0] - fscore = f1_score(truth,best_class,average='macro') - - return fscore diff --git a/proclam/metrics/mcc.py b/proclam/metrics/mcc.py new file mode 100644 index 0000000..0a926a6 --- /dev/null +++ b/proclam/metrics/mcc.py @@ -0,0 +1,67 @@ +""" +A class for the Matthews correlation coefficient +""" + +from __future__ import absolute_import +__all__ = ['MCC'] + +import numpy as np + +from .util import weight_sum, check_weights +from .util import prob_to_det, det_to_cm, cm_to_rate +from .metric import Metric + +class MCC(Metric): + + def __init__(self, scheme=None): + """ + An object that evaluates the Matthews correlation coefficient + + Parameters + ---------- + scheme: string + the name of the metric + """ + super(MCC, self).__init__(scheme) + self.scheme = scheme + + def evaluate(self, prediction, truth, averaging='per_class'): + """ + Evaluates the Matthews correlation coefficient + + Parameters + ---------- + prediction: numpy.ndarray, float + predicted class probabilities + truth: numpy.ndarray, int + true classes + averaging: string or numpy.ndarray, float + 'per_class' weights classes equally, other keywords possible, vector assumed to be class weights + + Returns + ------- + mcc_all: float + value of the metric + """ + prediction, truth = np.asarray(prediction), np.asarray(truth) + (N, M) = np.shape(prediction) + + dets = prob_to_det(prediction) + cm = det_to_cm(dets, truth) + rates = cm_to_rate(cm) + + mcc = np.empty(M) + for m in range(M): + if not len(np.where(truth == m)[0]): + raise RuntimeError('No true values for class %i so MCC is undefined'%m) + num = rates.TP[m] * rates.TN[m] - rates.FP[m] * rates.FN[m] + A = rates.TP[m] + rates.FP[m] + B = rates.TP[m] + rates.FN[m] + C = rates.TN[m] + rates.FP[m] + D = rates.TN[m] + rates.FN[m] + mcc[m] = num / np.sqrt(A * B * C * D) + + weights = 
check_weights(averaging, M, truth=truth) + mcc_all = weight_sum(mcc, weights) + + return mcc_all \ No newline at end of file diff --git a/proclam/metrics/prc.py b/proclam/metrics/prc.py new file mode 100644 index 0000000..5b4f1de --- /dev/null +++ b/proclam/metrics/prc.py @@ -0,0 +1,82 @@ +""" +A class for the Precision-Recall Curve +""" + +from __future__ import absolute_import +__all__ = ['PRC'] + +import numpy as np + +from .util import weight_sum, check_weights +from .util import prob_to_det, det_to_cm, cm_to_rate +from .util import auc, check_auc_grid, precision +from .metric import Metric + +class PRC(Metric): + + def __init__(self, scheme=None): + """ + An object that evaluates the PRC AUC + + Parameters + ---------- + scheme: string + the name of the metric + """ + super(PRC, self).__init__(scheme) + self.scheme = scheme + + def evaluate(self, prediction, truth, grid, averaging='per_class', vb=False): + """ + Evaluates the area under the PRC + + Parameters + ---------- + prediction: numpy.ndarray, float + predicted class probabilities + truth: numpy.ndarray, int + true classes + grid: numpy.ndarray, float or float or int + array of values between 0 and 1 at which to evaluate PRC + averaging: string or numpy.ndarray, float + 'per_class' weights classes equally, other keywords possible, vector assumed to be class weights + + Returns + ------- + auc_allclass: float + value of the metric + """ + thresholds_grid = check_auc_grid(grid) + n_thresholds = len(thresholds_grid) + + prediction, truth = np.asarray(prediction), np.asarray(truth) + (N, M) = np.shape(prediction) + + auc_class = np.empty(M) + curve = np.empty((M, 2, n_thresholds)) + + for m in range(M): + m_truth = (truth == m).astype(int) + + if not len(np.where(truth == m)[0]): + raise RuntimeError('No true values for class %i so PRC is undefined'%m) + + precisions, recalls = np.empty(n_thresholds), np.empty(n_thresholds) + for i, t in enumerate(thresholds_grid): + dets = prob_to_det(prediction, m, threshold=t) + cm = det_to_cm(dets, m_truth) + rates = cm_to_rate(cm) + recalls[i] = rates.TPR[-1] + precisions[i] = precision(rates.TP[-1], rates.FP[-1]) + + (curve[m][0], curve[m][1]) = (recalls, precisions) + auc_class[m] = auc(recalls, precisions) + if np.any(np.isnan(curve)): + print('Where did these NaNs come from?') + return curve + + weights = check_weights(averaging, M, truth=truth) + auc_allclass = weight_sum(auc_class, weights) + + if vb: return curve + else: return auc_allclass \ No newline at end of file diff --git a/proclam/metrics/precision_recall.py b/proclam/metrics/precision_recall.py deleted file mode 100644 index 9c52dc1..0000000 --- a/proclam/metrics/precision_recall.py +++ /dev/null @@ -1,75 +0,0 @@ -""" -A superclass for metrics -""" - -from __future__ import absolute_import -__all__ = ['Metric'] - -import numpy as np - -from .util import weight_sum -from .util import check_weights -from .util import prob_to_det_threshold -from .util import auc, precision, recall -from scipy.integrate import trapz -from sklearn.metrics import precision_recall_curve - -class Metric(object): - - def __init__(self, scheme=None, **kwargs): - """ - An object that evaluates the F-score - - Parameters - ---------- - scheme: string - the name of the metric - """ - - self.debug = False - self.scheme = scheme - - def evaluate(self, prediction, truth, gridspace=0.01, weights=None, **kwds): - """ - Evaluates the area under the ROC curve for a given class_idx - - Parameters - ---------- - prediction: numpy.ndarray, float - predicted class 
probabilities - truth: numpy.ndarray, int - true classes - weights: numpy.ndarray, float - per-class weights - - Returns - ------- - metric: float - value of the metric - """ - - auc_allclass = 0 - n_class = np.shape(prediction)[1] - if not weights: - weights = [1./n_class]*n_class - - for class_idx in range(n_class): - if not len(np.where(truth == class_idx)[0]): - raise RuntimeError('No true values for class %i so ROC is undefined'%class_idx) - - truth_bool = np.zeros(len(truth),dtype=bool) - truth_bool[truth == class_idx] = 1 - P,R,thresholds_grid = precision_recall_curve(truth_bool,prediction[:,class_idx]) - - auc_class = auc(R,P) - - if self.debug: - import pylab as plt - plt.clf() - plt.plot(R,P) - plt.show() - import pdb; pdb.set_trace() - - auc_allclass += auc_class*weights[class_idx] - - return auc_allclass diff --git a/proclam/metrics/roc.py b/proclam/metrics/roc.py index ac510d3..bd24111 100644 --- a/proclam/metrics/roc.py +++ b/proclam/metrics/roc.py @@ -7,72 +7,73 @@ import numpy as np -from .util import weight_sum -from .util import check_weights -from .util import check_auc_grid -from .util import prob_to_det -from .util import det_to_cm -from .util import cm_to_rate -from .util import auc +from .util import weight_sum, check_weights +from .util import prob_to_det, det_to_cm, cm_to_rate +from .util import auc, check_auc_grid, prep_curve from .metric import Metric class ROC(Metric): - def __init__(self, scheme=None): - """ - An object that evaluates the ROC metric - - Parameters - ---------- - scheme: string - the name of the metric - """ - super(ROC, self).__init__(scheme) - self.scheme = scheme - - def evaluate(self, prediction, truth, grid, averaging='per_class'): - """ - Evaluates the area under the ROC curve for a given class_idx - - Parameters - ---------- - prediction: numpy.ndarray, float - predicted class probabilities - truth: numpy.ndarray, int - true classes - grid: numpy.ndarray, float or float or int - array of values between 0 and 1 at which to evaluate ROC - averaging: string or numpy.ndarray, float - 'per_class' weights classes equally, other keywords possible, vector assumed to be class weights - - Returns - ------- - metric: float - value of the metric - """ - thresholds_grid = check_auc_grid(grid) - n_thresholds = len(thresholds_grid) - - prediction, truth = np.asarray(prediction), np.asarray(truth) - (N, M) = np.shape(prediction) - - auc_class = np.empty(M) - - for m in range(M): - if not len(np.where(truth == m)[0]): - raise RuntimeError('No true values for class %i so ROC is undefined'%m) - - tpr, fpr = np.zeros(n_thresholds), np.zeros(n_thresholds) - for i, t in enumerate(thresholds_grid): - dets = prob_to_det(prediction, m, threshold=t) - cm = det_to_cm(dets, truth) - rates = cm_to_rate(cm) - tpr[i] = rates.TPR[m] - fpr[i] = rates.FPR[m] - - auc_class[m] = auc(fpr, tpr) - - weights = check_weights(averaging, M, truth=truth) - auc_allclass = weight_sum(auc_class, weights) - - return auc_allclass + def __init__(self, scheme=None): + """ + An object that evaluates the ROC metric + + Parameters + ---------- + scheme: string + the name of the metric + """ + super(ROC, self).__init__(scheme) + self.scheme = scheme + + def evaluate(self, prediction, truth, grid, averaging='per_class', vb=False): + """ + Evaluates the ROC AUC + + Parameters + ---------- + prediction: numpy.ndarray, float + predicted class probabilities + truth: numpy.ndarray, int + true classes + grid: numpy.ndarray, float or float or int + array of values between 0 and 1 at which to evaluate 
ROC + averaging: string or numpy.ndarray, float + 'per_class' weights classes equally, other keywords possible, vector assumed to be class weights + + Returns + ------- + auc_allclass: float + value of the metric + """ + thresholds_grid = check_auc_grid(grid) + n_thresholds = len(thresholds_grid) + + prediction, truth = np.asarray(prediction), np.asarray(truth) + (N, M) = np.shape(prediction) + + auc_class = np.empty(M) + curve = np.empty((M, 2, n_thresholds)) + + for m in range(M): + m_truth = (truth == m).astype(int) + + if not len(np.where(truth == m)[0]): + raise RuntimeError('No true values for class %i so ROC is undefined'%m) + + tpr, fpr = np.empty(n_thresholds), np.empty(n_thresholds) + for i, t in enumerate(thresholds_grid): + dets = prob_to_det(prediction, m, threshold=t) + cm = det_to_cm(dets, m_truth) + rates = cm_to_rate(cm) + fpr[i], tpr[i] = rates.FPR[-1], rates.TPR[-1] + + (curve[m][0], curve[m][1]) = (fpr, tpr) + (fpr, tpr) = prep_curve(fpr, tpr) + auc_class[m] = auc(fpr, tpr) + + weights = check_weights(averaging, M, truth=truth) + auc_allclass = weight_sum(auc_class, weights) + + if vb: return curve + else: return auc_allclass diff --git a/proclam/metrics/util.py b/proclam/metrics/util.py index 73b6b6b..ae4b81a 100644 --- a/proclam/metrics/util.py +++ b/proclam/metrics/util.py @@ -5,8 +5,8 @@ from __future__ import absolute_import, division __all__ = ['sanitize_predictions', 'weight_sum', 'check_weights', 'averager', - 'cm_to_rate', - 'auc', 'check_auc_grid', + 'cm_to_rate', 'precision', + 'auc', 'check_auc_grid', 'prep_curve', 'det_to_prob', 'prob_to_det', 'det_to_cm'] @@ -16,7 +16,7 @@ import sys from scipy.integrate import trapz -RateMatrix = collections.namedtuple('rates', 'TPR FPR FNR TNR') +RateMatrix = collections.namedtuple('rates', 'TPR FPR FNR TNR TP FP FN TN') def sanitize_predictions(predictions, epsilon=1.e-8): """ @@ -43,17 +43,16 @@ def sanitize_predictions(predictions, epsilon=1.e-8): predictions = predictions / np.sum(predictions, axis=1)[:, np.newaxis] return predictions -def weight_sum(per_class_metrics, weight_vector, norm=True): +def weight_sum(per_class_metrics, weight_vector): """ Calculates the weighted metric Parameters ---------- - per_class_metrics: numpy.float - the scores separated by class (a list of arrays) - weight_vector: numpy.ndarray floar - The array of weights per class - norm: boolean, optional + per_class_metrics: numpy.ndarray, float + vector of per-class scores + weight_vector: numpy.ndarray, float + vector of per-class weights Returns ------- @@ -61,7 +60,6 @@ def weight_sum(per_class_metrics, weight_vector, norm=True): The weighted metric """ weight_sum = np.dot(weight_vector, per_class_metrics) - return weight_sum def check_weights(avg_info, M, chosen=None, truth=None): @@ -112,6 +110,7 @@ def check_weights(avg_info, M, chosen=None, truth=None): weights[chosen] = 1./np.float(M) else: print('something has gone wrong with avg_info '+str(avg_info)) + weights = None return weights def averager(per_object_metrics, truth, M, vb=False): @@ -157,10 +156,11 @@ def cm_to_rate(cm, vb=False): ----- This can be done with a mask to weight the classes differently here. 
""" + cm = cm.astype(float) # if vb: print('by request cm '+str(cm)) tot = np.sum(cm) - tra = np.trace(cm) - # if vb: print('by request sum, trace '+str((tot, tra))) + # mask = range(len(cm)) + # if vb: print('by request sum '+str(tot)) T = np.sum(cm, axis=1) F = tot[np.newaxis] - T @@ -170,32 +170,52 @@ def cm_to_rate(cm, vb=False): TP = np.diag(cm) FN = P - TP - FP = T - TP#np.sum(cm - np.diag(cm)[:,np.newaxis], axis=0)# np.sum(np.tril(cm), 1), axis=1) - TN = F - FN#np.sum(cm - np.diag(cm)[np.newaxis], axis=1)# np.sum(np.triu(cm, 1), axis=0) + TN = F - FN + FP = T - TP # if vb: print('by request TP, FP, FN, TN'+str((TP, FP, FN, TN))) - # P = TP + FP - # N = TN + FN TPR = TP / P FPR = FP / N FNR = FN / P TNR = TN / N # if vb: print('by request TPR, FPR, FNR, TNR'+str((TPR, FPR, FNR, TNR))) - rates = RateMatrix(TPR=TPR, FPR=FPR, FNR=FNR, TNR=TNR) + rates = RateMatrix(TPR=TPR, FPR=FPR, FNR=FNR, TNR=TNR, TP=TP, FN=FN, TN=TN, FP=FP) # if vb: print('by request TPR, FPR, FNR, TNR '+str(rates)) return rates +def prep_curve(x, y): + """ + Makes a curve for AUC + + Parameters + ---------- + x: numpy.ndarray, float + x-axis + y: numpy.ndarray, float + y-axis + + Returns + ------- + x: numpy.ndarray, float + x-axis + y: numpy.ndarray, float + y-axis + """ + x = np.concatenate(([0.], x, [1.]),) + y = np.concatenate(([0.], y, [1.]),) + return (x, y) + def auc(x, y): """ Computes the area under curve (just a wrapper for trapezoid rule) Parameters ---------- - x: numpy.ndarray, int or float + x: numpy.ndarray, float x-axis - y: numpy.ndarray, int or float + y: numpy.ndarray, float y-axis Returns @@ -203,8 +223,6 @@ def auc(x, y): auc: float the area under the curve """ - x = np.concatenate(([0.], x, [1.]),) - y = np.concatenate(([0.], y, [1.]),) i = np.argsort(x) auc = trapz(y[i], x[i]) return auc @@ -224,7 +242,7 @@ def check_auc_grid(grid): grid of thresholds """ if type(grid) == list or type(grid) == np.ndarray: - thresholds_grid = np.array(grid) + thresholds_grid = np.concatenate((np.zeros(1), np.array(grid), np.ones(1))) elif type(grid) == float: if grid > 0. 
and grid < 1.: thresholds_grid = np.arange(0., 1., grid) @@ -237,7 +255,7 @@ def check_auc_grid(grid): thresholds_grid = None try: assert thresholds_grid is not None - return thresholds_grid + return np.sort(thresholds_grid) except AssertionError: print('Please specify a grid, spacing, or density for this AUC metric.') return @@ -302,12 +320,12 @@ class relative to binary decision assert(type(m) == int and type(threshold) == np.float64) except AssertionError: print(str(m)+' is '+str(type(m))+' and must be int; '+str(threshold)+' is '+str(type(threshold))+' and must be float') - dets = np.zeros(np.shape(probs)[0]) + dets = np.zeros(np.shape(probs)[0]).astype(int) dets[probs[:, m] >= threshold] = 1 return dets -def det_to_cm(dets, truth, per_class_norm=True, vb=False): +def det_to_cm(dets, truth, per_class_norm=False, vb=False): """ Converts deterministic classifications and truth into confusion matrix @@ -338,57 +356,47 @@ def det_to_cm(dets, truth, per_class_norm=True, vb=False): M = np.int(max(max(pred_classes), max(true_classes)) + 1) # if vb: print('by request '+str((np.shape(dets), np.shape(truth)), M)) - cm = np.zeros((M, M), dtype=float) + cm = np.zeros((M, M), dtype=int) coords = np.array(list(zip(dets, truth))) indices, index_counts = np.unique(coords, axis=0, return_counts=True) if vb: print(indices.T, index_counts) index_counts = index_counts.astype(int) indices = indices.T.astype(int) - # if vb: print('by request '+str(index_counts)) - # if vb: print(indices, index_counts) - # indices = indices.T - # if vb: print(indices) - # if vb: print(np.shape(indices)) cm[indices[0], indices[1]] = index_counts - # if vb: print(cm) if per_class_norm: - # print(type(cm)) - # print(type(true_counts)) - # cm = cm / true_counts - # cm /= true_counts[:, np.newaxis] # - cm = cm / true_counts[np.newaxis, :] + cm = cm.astype(float) / true_counts[np.newaxis, :].astype(float) - # if vb: print('by request '+str(cm)) + if vb: print('by request '+str(cm)) return cm # def prob_to_cm(probs, truth, per_class_norm=True, vb=False): - """ - Turns probabilistic classifications into confusion matrix by taking maximum probability as deterministic class - - Parameters - ---------- - probs: numpy.ndarray, float - N * M matrix of class probabilities - truth: numpy.ndarray, int - N-dimensional vector of true classes - per_class_norm: boolean, optional - equal weight per class if True, equal weight per object if False - vb: boolean, optional - if True, print cm - - Returns - ------- - cm: numpy.ndarray, int - confusion matrix - """ - dets = prob_to_det(probs) - - cm = det_to_cm(dets, truth, per_class_norm=per_class_norm, vb=vb) - - return cm +# """ +# Turns probabilistic classifications into confusion matrix by taking maximum probability as deterministic class +# +# Parameters +# ---------- +# probs: numpy.ndarray, float +# N * M matrix of class probabilities +# truth: numpy.ndarray, int +# N-dimensional vector of true classes +# per_class_norm: boolean, optional +# equal weight per class if True, equal weight per object if False +# vb: boolean, optional +# if True, print cm +# +# Returns +# ------- +# cm: numpy.ndarray, int +# confusion matrix +# """ +# dets = prob_to_det(probs) +# +# cm = det_to_cm(dets, truth, per_class_norm=per_class_norm, vb=vb) +# +# return cm #def cm_to_rate(cm, vb=False): # """ @@ -633,21 +641,26 @@ def auc(x, y): # """ # return 1. 
- rates.FNR -# def precision(rates): -# """ -# Calculates precision from rates -# -# Parameters -# ---------- -# rates: namedtuple -# named tuple of 'TPR FPR FNR TNR' -# -# Returns -# ------- -# precision: float -# precision -# """ -# return 1. - rates.FNR +def precision(TP, FP): + """ + Calculates precision from true and false positive counts + + Parameters + ---------- + TP: float + number of true positives + FP: float + number of false positives + + Returns + ------- + p: float + precision + """ + p = np.asarray(TP / (TP + FP)) + if np.any(np.isnan(p)): + p[np.isnan(p)] = 0. + return p # # def recall(classifications,truth,class_idx): #
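# --- Illustrative usage sketch (not part of the change above) ---
# A minimal example, under stated assumptions, of how the new Accuracy, F1,
# MCC, and PRC metrics and the revised ROC metric are expected to be called,
# based on their evaluate() signatures. The 3-class probabilities and truth
# vector below are made-up toy data, and the example assumes the updated
# proclam package is importable; 'per_class' averaging weights classes equally.
import numpy as np
from proclam.metrics import Accuracy, F1, MCC, PRC, ROC

prediction = np.array([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.2, 0.3, 0.5],
                       [0.6, 0.3, 0.1],
                       [0.1, 0.2, 0.7],
                       [0.2, 0.6, 0.2]])
truth = np.array([0, 1, 2, 0, 2, 1])  # every class is represented, so F1/MCC/PRC are defined

print(Accuracy().evaluate(prediction, truth, averaging='per_class'))
print(F1().evaluate(prediction, truth, averaging='per_class'))
print(MCC().evaluate(prediction, truth, averaging='per_class'))
# PRC and ROC need a threshold grid; a float spacing is expanded by check_auc_grid.
print(PRC().evaluate(prediction, truth, grid=0.1, averaging='per_class'))
# With vb=True, ROC returns the per-class (FPR, TPR) curves, shape (M, 2, n_thresholds).
curves = ROC().evaluate(prediction, truth, grid=0.1, vb=True)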