diff --git a/examples/model_selection/RF_vs_SPORF_random_search.ipynb b/examples/model_selection/RF_vs_SPORF_random_search.ipynb new file mode 100644 index 0000000000000..e27c3c20b64fa --- /dev/null +++ b/examples/model_selection/RF_vs_SPORF_random_search.ipynb @@ -0,0 +1,326 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Demonstration of randomized search to compare classifier performance\n", + "============================================================================\n", + "An important step in classifier performance comparison is hyperparameter \n", + "optimization. Here, we specify the classifer models we want to tune and a \n", + "dictionary of hyperparameter ranges (preferably similar for fairness in \n", + "comparision) for each classifier. Then, we find the optimal hyperparameters \n", + "through a function that uses RandomizedSearchCV and refit the optimized \n", + "models to obtain accuracies. We can see clearly in the plot that the \n", + "optimized models perform better than or similar to the default parameter models. On the \n", + "dataset we use in this example, car dataset from OpenML-CC18, SPORF also \n", + "performs better than RF overall. \n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Automatically created module for IPython interactive environment\n" + ] + } + ], + "source": [ + "print(__doc__)\n", + "\n", + "from sklearn.model_selection import RandomizedSearchCV\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import math\n", + "from rerf.rerfClassifier import rerfClassifier\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.datasets import fetch_openml\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn import metrics\n", + "\n", + "from warnings import simplefilter\n", + "simplefilter(action=\"ignore\", category=FutureWarning)\n", + "from warnings import simplefilter\n", + "simplefilter(action=\"ignore\", category=FutureWarning)\n", + "\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def hyperparameter_optimization_random(X, y, *argv):\n", + " \"\"\"\n", + " Given a classifier and a dictionary of hyperparameters, find optimal hyperparameters using RandomizedSearchCV.\n", + "\n", + " Parameters\n", + " ----------\n", + " X : numpy.ndarray\n", + " Input data, shape (n_samples, n_features)\n", + " y : numpy.ndarray\n", + " Output data, shape (n_samples, n_outputs)\n", + " *argv : list of tuples (classifier, hyperparameters)\n", + " List of (classifier, hyperparameters) tuples:\n", + "\n", + " classifier : sklearn-compliant classifier\n", + " For example sklearn.ensemble.RandomForestRegressor, rerf.rerfClassifier, etc\n", + " hyperparameters : dictionary of hyperparameter ranges\n", + " See https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html.\n", + "\n", + " Returns\n", + " -------\n", + " clf_best_params : dictionary\n", + " Dictionary of best hyperparameters\n", + " \"\"\"\n", + "\n", + " clf_best_params = {}\n", + "\n", + " # Iterate over all (classifier, hyperparameters) pairs\n", + " for clf, params in argv:\n", + "\n", + " # Run randomized search\n", + " n_iter_search = 10\n", + " random_search = RandomizedSearchCV(\n", + " clf, param_distributions=params, n_iter=n_iter_search, cv=10, iid=False\n", + " )\n", + " random_search.fit(X, y)\n", + "\n", + " # Save results\n", + " clf_best_params[clf] = random_search.best_params_\n", + "\n", + " return clf_best_params" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Building classifiers and specifying parameter ranges to sample from\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# get some data\n", + "X, y = fetch_openml(data_id=40975, return_X_y=True, as_frame=True) #car dataset\n", + "y = pd.factorize(y)[0]\n", + "X = X.apply(lambda x: pd.factorize(x)[0])\n", + "n_features = np.shape(X)[1]\n", + "n_samples = np.shape(X)[0]\n", + "\n", + "# build a classifier\n", + "rerf = rerfClassifier()\n", + "\n", + "# specify max_depth and min_sample_splits ranges\n", + "max_depth_array_rerf = (np.unique(np.round((np.linspace(2, n_samples, 10))))).astype(\n", + " int\n", + ")\n", + "max_depth_range_rerf = np.append(max_depth_array_rerf, None)\n", + "\n", + "min_sample_splits_range_rerf = (\n", + " np.unique(\n", + " np.round((np.arange(1, math.log(n_samples), (math.log(n_samples) - 2) / 10)))\n", + " )\n", + ").astype(int)\n", + "\n", + "# specify parameters and distributions to sample from\n", + "rerf_param_dict = {\n", + " \"n_estimators\": np.arange(50, 550, 50),\n", + " \"max_depth\": max_depth_range_rerf,\n", + " \"min_samples_split\": min_sample_splits_range_rerf,\n", + " \"feature_combinations\": [1, 2, 3, 4, 5],\n", + " \"max_features\": [\"sqrt\", \"log2\", None, n_features ** 2],\n", + "}\n", + "\n", + "# build another classifier\n", + "rf = RandomForestClassifier()\n", + "\n", + "# specify max_depth and min_sample_splits ranges\n", + "max_depth_array_rf = (np.unique(np.round((np.linspace(2, n_samples, 10))))).astype(int)\n", + "max_depth_range_rf = np.append(max_depth_array_rf, None)\n", + "\n", + "min_sample_splits_range_rf = (\n", + " np.unique(\n", + " np.round((np.arange(2, math.log(n_samples), (math.log(n_samples) - 2) / 10)))\n", + " )\n", + ").astype(int)\n", + "\n", + "# specify parameters and distributions to sample from\n", + "rf_param_dict = {\n", + " \"n_estimators\": np.arange(50, 550, 50),\n", + " \"max_depth\": max_depth_range_rf,\n", + " \"min_samples_split\": min_sample_splits_range_rf,\n", + " \"max_features\": [\"sqrt\", \"log2\", None],\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Obtaining best parameters dictionary and refitting" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{rerfClassifier(feature_combinations=1.5, image_height=None, image_width=None,\n", + " max_depth=None, max_features='auto', min_samples_split=1,\n", + " n_estimators=500, n_jobs=None, oob_score=False,\n", + " patch_height_max=None, patch_height_min=1, patch_width_max=None,\n", + " patch_width_min=1, projection_matrix='RerF', random_state=None): {'n_estimators': 300, 'min_samples_split': 4, 'max_features': 36, 'max_depth': 577, 'feature_combinations': 2}, RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,\n", + " criterion='gini', max_depth=None, max_features='auto',\n", + " max_leaf_nodes=None, max_samples=None,\n", + " min_impurity_decrease=0.0, min_impurity_split=None,\n", + " min_samples_leaf=1, min_samples_split=2,\n", + " min_weight_fraction_leaf=0.0, n_estimators=100,\n", + " n_jobs=None, oob_score=False, random_state=None,\n", + " verbose=0, warm_start=False): {'n_estimators': 200, 'min_samples_split': 6, 'max_features': None, 'max_depth': 1153}}\n" + ] + } + ], + "source": [ + "best_params = hyperparameter_optimization_random(\n", + " X, y, (rerf, rerf_param_dict), (rf, rf_param_dict)\n", + ")\n", + "print(best_params)\n", + "\n", + "# extract values from dict - seperate each classifier's param dict\n", + "keys, values = zip(*best_params.items())\n", + "\n", + "# train test split\n", + "X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y, test_size=0.33, random_state=42\n", + ")\n", + "\n", + "# get accuracies of optimized and default models\n", + "rerf_opti = rerfClassifier(**values[0])\n", + "rerf_opti.fit(X_train, y_train)\n", + "rerf_pred_opti = rerf_opti.predict(X_test)\n", + "rerf_acc_opti = metrics.accuracy_score(y_test, rerf_pred_opti)\n", + "\n", + "rerf_default = rerfClassifier()\n", + "rerf_default.fit(X_train, y_train)\n", + "rerf_pred_default = rerf_default.predict(X_test)\n", + "rerf_acc_default = metrics.accuracy_score(y_test, rerf_pred_default)\n", + "\n", + "rf_opti = RandomForestClassifier(**values[1])\n", + "rf_opti.fit(X_train, y_train)\n", + "rf_pred_opti = rf_opti.predict(X_test)\n", + "rf_acc_opti = metrics.accuracy_score(y_test, rf_pred_opti)\n", + "\n", + "rf_default = RandomForestClassifier()\n", + "rf_default.fit(X_train, y_train)\n", + "rf_pred_default = rf_default.predict(X_test)\n", + "rf_acc_default = metrics.accuracy_score(y_test, rf_pred_default)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Plotting the result" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAa4AAAEYCAYAAAAEZhLyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deZgV1bX38e9iEhEURDSBVsEgCCqigmgc0DggqIgQFTEKepXXiNE4cNWYoGKMJJIY74UkDteLOIZoJBgxykVRk2gUB5AhMiMNioCAgCLTev/Yu9vq06e7D9CH7ur+fZ6nnz5Vtatq1bhq12jujoiISFrUqeoAREREtocSl4iIpIoSl4iIpIoSl4iIpIoSl4iIpIoSl4iIpIoS104ws/3M7HUzW2dmv96F451pZifvYL8vmtnASo7nDjN7vDKHmTF8M7P/NbPVZvb2Tg6rtZm5mdWrrPhqKjNbZGanVXUcOyou57Y5lDvZzAp3RUxSOXYqcZnZlLgz2a2yAkqZwcBKYE93vzFbATP7rpm9EpPbWjN73sw65joCMxtjZj9PtnP3Q919yo4E7O493f3RHel3e5jZADN7MpEo1se/5Wb2VzM7fTsGdwJwOlDg7sdUcpxTzOyKCsr8h5n9Oy7D5Wb2gpk1id3GmNmmOG2fm9kkMzsk0W9HM5sQl/06M3vVzL6b6J45fxaZ2S0Z419kZl8lyqw3s5aVOR+2Vw7TPcjMtmbEPKqMYU2J8+CIjPbjY/uT8zw5NVqchxviMlhlZpPN7MLt6H+XJPbtGc8OJy4zaw2cCDjQe0eHs4Pjri5HywcCs7yMp7jN7DjgZeAvQEugDTAN+IeZHbTLoqwavYCJieam7t4YOAKYBDxnZoNyHNaBwCJ331C5IVbMzLoDvwAucvcmQAdgXEaxX8VpKwA+A8bEfr8D/AP4kLDsWwLPAS/HdSOpaP58H/hZlsR+jrs3Tvwtq7SJ3HFF090KWAr8T0b3NzNivqacYc0BLi1qMLPmwLHAisoOuqaqYL94RFxW7Qnr5ygzu32XBJYP7r5Df8Awwkb5G+CvGd12B34NLAbWAn8Hdo/dTgD+CawBlgCDYvspwBWJYQwC/p5odmAIMBdYGNvdH4fxBfAucGKifF3gJ8B8YF3svj8wGvh1RrzPAz8uYzq/C7wTp+Md4Lux/RhgM7AJWA+clqXfN4DfZWn/IjA2/j4ZKIyxrgQWARfHboMzxvF8bL+oaHzAHcCfgMfjdH4ItANuJexElwBnJMZdPJ8JSXR94s+Bk2O3YxPLaVpR+9itDfBaHN8kYBTweKJ7HWA5sA/QOg63XsY8uCmWqRObWwLPEnZUC4FrY/v/ADYCW2OMdwLNgL/Gsqvj74LEsIvnT2IePR5/F8cD3B2HuzEOe1SWZXUTML6c7WAM8PNE81nA+vj7MWBiln5+D7yeGU+i+9vA0LKmp5xYKpovU4C7CNvtOsJB1T6J7pcQttlVwG3ljTfLdPcCNpS1/VYQ9xTC/qQQqBvbXRPnUyHfrJO7Ab8FlsW/3wK7JYYzFPgkdrs8zte2iX5HAh8T1rs/8M0+6WSgMDGcmwmJeB3wEXBqGXHvBYyN83sx8FO+WZ8HEfZ7I+OyWAj0LGce7A/8OQ5rVdG6CHwHeCW2Wwk8QTjISa4bNwPTga/J2M4S+862Ge2+T1jvm8fmy4DZcZoXAP8vtt8D+ArYxjf7iZbAMcCbhP3DJ4R9QIPYjwH3EfY/a2Nsh5W3HMoaT5nzK5cVq4wZPQ+4GjiasHPdL9FtNGFlbEVIIN+NAR8QZ8xFQH2gOdA5c4eabcWPM38SsHdihftBHEY94EbgU6BhYiX+kHCEYYQj/eZxhi9LrGD7AF8m40+Mc++40l0Sx3FRbC5a2GNIbLwZ/TYi7BRPydLtMuCTxEazhXAAsBvQHdgAtC9rHJROXBuBHjHGsYSN5LY4j68kJvps8znRfjDwb2DPuNxWEXZGdQin6VYBLWLZNxPxnhSXaTJxHUs42oayE9dBsX2HOI53CTuvBrHbAqBHGetCc6BfnMdNCIl7fLb5k5hHpRJXefMj0e+JhA3qTuB4EjvKzOUDNAaeBN6IzZ8Cl2UZ5ilx3WiUJZ5jCevjeWVNTzmxVjRfphAO5NoRdhZTgBGxW0fCzuKkuFx/Q1gvK0xchJ3OY8C0srbfCuKeAlxBSKQ9Y7u3geMombiGA28B+wItCAdWd8VuZxJ2hIfFeJ6kZOL6LTCBsE03IRys3pPYBgvj7/aEg72WifXlO2XEPZZwNqVJLDcH+I/E9G8mbH91gR8S9juWZTh1CQeH98XYGwInxG5tCdvfbnGaXwd+m7FufEBIfLuXEWe2xFU/Lt+i+X0WIUkaYR/0JXBU5vxJ9H80YV2tF6d9NvHgn7AvehdoGofXAfj29iyHCteZXAplmREnxIWyT2z+N3B9/F2HsKEfkaW/W4Hnylt5y1rx48z/XgVxrS4aL+FI6dwyys0GTo+/ryHLUXHsdgnwdka7N/mmljiGshNXQYz5kCzdzgQ2JxbWFmCPRPdxwM/KGgelE9ekRLdzCDugoiPXJjGOptnmc2J5fga0i803A49llHkJGEg4+MiM90lKJq67EvG3JnviahjbHw90Az7Osq78b7Z1Icv87AyszjZ/EvNohxJXLNOTsIGtifP2N4n5O4Zw4LCGkKgmEHd0cT6dmWV4h8QYWiXiWUPYbpxwRGoZ07M+lllDOTXACubLFOCnieargb/F38OApxPd9iDU9MtLXEXTvY1wsNQpY/vdkoh5DXBseds+4UD0KULymBO7JRPXfKBXor8ehFPIAI8Qk3BsbhfnZVvCznMDiQRESIoLE9tgUeJqS9gWTgPqlzNv6xJqOB0T7f4fMCUx/fMS3RrFeL6VZVjHEWpapWpLWcr2Ad7PWDcur6CfUokrtv+UeHYnS7fxwHWZ86eccfyYuG8HvkdI4scSKwixfc7LoaK/Hb3GNRB42d1XxuYnYzsINZiGhJUs0/5ltM/VkmSDmd1oZrPjhe81hKr7PjmM61HCRkL8/1gZ5VoSTgEkLSbscCqymrBBfztLt28Tqv3FZb3k9ZvFcdy5Wp74/RWw0t23Jpoh1AZKMbP9CYlyoLvPia0PBM43szVFf4Tk9u0YV7Z4kzKvb2VTNA8/j+NrmTG+nwD7lRFzIzN7wMwWm9kXhKPQpmZWt4Jx7hB3f9HdzyEcJZ5L2Cklb+gY6e5N3f1b7t7b3YvWu5WUvfy3EdaRIvsQltFNhA24fkY/feI4mrp7n2xx5jhfPk38/pJv1ouWJLavuHxXZRtPwkh3b0pIvl8REk7SW4mYm7r7WxUM78+End6PyL5NZm6Pye2kRPwZ5VoQEse7ifXrb7F9Ce4+j7ATvgP4zMyeLuNGmH0IZwcy40nuG4rntbt/GX9m2w73Bxa7+5bMDma2b4xhaVymj/PNPq7Iksz+KmJm9QnT/3ls7mlmb8UbbdYQtuHM8ST7bxdvsvo0xvWLovLu/grh1OFoYLmZPWhme7Idy6Ei2524zGx34AKgewz6U+B64Ih4V9BKwpHYd7L0vqSM9hAycaNE87eylPFEHCcSagYXAM3iBrSWkNUrGtfjwLkx3g6Eo4tslhF2qkkHEM5/lytu+G8C52fpfAEwOdHczMz2yBhH0cV3J0/ishxPOPXwYqLTEkKNK7nT2cPdRxDOZ2eLt2iY3yLsmN+rYPTnEY5sP4rjW5gxvibu3quMfm8k7CS7ufuehNNb8M2yz2VdKpLz/HX3be4+mXDN4bAcevk/yl7+byZ2ZkXD3+ruvyZsP1fnGldCRfOlPJ8QdqChB7NGhFOPFXL3j4HrgPvjOrVD4vx4kXBaLVviytwek9tJifhJrJOEfdJXwKGJ9WsvDzcrZIvjSXc/IY7LgV9mKbaScNYpM54K9w1ZLAEOKOPmintiDJ3iMv0BpZfnjuwjziXUiN+2cFf4s4Sa/n5xXzoxMZ5sw/894UzbwTGunyTjcvf/cvejgUMJtd+hVLwccp6OHalx9SGcn+9IOBXRmbDzfwO41N23EartvzGzlmZW18yOizPnCeA0M7vAzOqZWXMz6xyH+wHQNx41tiVclC9PE8KMXwHUM7NhhOszRR4G7jKzgy3oFO9Uwt0LCTdaPAY86+5fkd1EoJ2FW7vrxVtIOxIueufiFmCgmV1rZk3MrJmFW9uPI1wzSbrTzBrEhHw24foEhNpUvu5AfAT4t7v/KqP948A5ZtYjLr+G8VbVAndfDExNxHsC4fRkkV6E009ZV0ILz75dA9wO3BrXl7eBL8zsZjPbPY7zMDPrWkbcTQgbwBoz2zsOK+kDoL+Z1TezLoQL0WUpd/6a2blm1j8uOzOzYwjXACqqPUBYxt81s7vNbO+4DvyIcPfczeX0NwL4TzNrmMM4kiqaL+V5BjjbzE4wswaE60k57x/cfRIhiQzejnFm8xOgu7svytLtKeCnZtbCzPYhnN4sen5wHDDIwuMHjUhMe1zHHgLuM7N9AcyslZn1yByBmbU3s+/F/dVGwvzcmlkuntEYB9wdl+uBwA2JeLbH24TEO8LM9ojb2/GxWxPiaWIza0VIADssrocXE2pDv3T3VYSa426EfekWM+sJnJHobTnQ3Mz2SrRrQrgpbr2FxyB+mBhHVzPrFmt1G4g3V+WwHLKNJ6sdSVwDCdcePnb3T4v+CFXDi+NRw02EGyPeIVRFf0k41/kxYcd2Y2z/AeGmCQgXJjfF4B8lJLnyvEQ4OptDqKJvpGSV+TeEFetlwgz+H8IF6SKPAodT9mlC4kI9O8a7CvhP4OzEKdJyufvfCefh+xJWzMXAkYQLr3MTRT8lnDZaRpjuq9z937Hb/wAdY9W6rJrhjuoPnGcln7U50d2XEI7IfkJYmZcQNpii9WUA4brU54QdxNjEMMs6TbjGzDYQ1otewPnu/ggU7wTOIRwELSQcmT1MOPWbzW8Jy3IlIYH8LaP7zwi17dWE5PFkOfPgfuD7Fp5H/K8s3VcTLrDPJaxHjwP3untF6ydxGZ9AWMcXEdaBfoSbTv5RTq8vJMa7PSqaL+XFOpNw1+6TMc7VhOtL2+NeQsLd4ec63X1Z3G6y+TnhoGk6YT16L7YjnjH4LaE2PC/+T7o5tn8rntr6P0qf2oSwAx9BmIefEm4E+UkZ8fyIsGNeQLiD8EnCweB2Saz/bQl32xUCRc9Z3QkcRTib9ALhdOqOmGZm6wnz4ArCPQnD4vjXAdcS9perCdv3hER8/yYcNCyI+6GWhH38AMKNWQ8Bf0yMa8/YbjXf3KU6MnYrczmUMZ6srIwD4xrPzE4i7IRaxyOBqorjZMKNAwVVFUNliQctnxIuvq6t6nhEpGaqla98ilXY64CHqzJp1UB7E+4mVNISkbzJW+Iys0fM7DMzm1FGdzOz/zKzeWY23cyOSnQbaGZz49/AbP3vRFwdCLfmfptwakEqibt/5u6/r+o4RKRmy9upwngqbj3hDRGl7sAys16Ec8S9CNdL7nf3bvGi8lSgC+Euk3eBo919deYwRESk9slbjcvdXyc+I1CGcwlJzePzHU3N7NuEmxkmufvnMVlNIjywKyIiQlW+rLYVJe8CLIztympfipkNJt5+u8ceexx9yCGHZCsmIiJZvPvuuyvdfbsfAK5qVZm4sj0U6eW0L93S/UHgQYAuXbr41KlTKy86EZEazswy33qTClV5V2EhJZ90LyA8x1RWexERkSpNXBOAS+PdhccCa939E8KDxWfENxU0IzzB/VIVxikiItVI3k4VmtlThJeF7mPhq5a3E18c6u5/ILxdoRfhKeovCZ/6wN0/N7O7CG/dABju7uXd5CEiIrVI3hKXu19UQXcnvGImW7dH2IFXp4hIum3evJnCwkI2btxY1aHUKA0bNqSgoID69TM/OpBOVXlzhohICYWFhTRp0oTWrVtjlstL7aUi7s6qVasoLCykTZs2VR1OpaiVr3wSkepp48aNNG/eXEmrEpkZzZs3r1G1WCUuEalWlLQqX02bp0pcIiKSKrrGJSLVVutbXqjU4S0acVZO5QoLCxkyZAizZs1i27ZtnH322dx77700aNAga/k1a9bw5JNPcvXV4cPVy5Yt49prr+WZZ57JObZhw4Zx0kkncdppp+XcTzaNGzdm/fr1OzWM6k41LhGRBHenb9++9OnTh7lz5zJnzhzWr1/PbbfdVmY/a9as4Xe/+11xc8uWLbcraQEMHz58p5NWbaEal4hIwiuvvELDhg257LLLAKhbty733Xcfbdq0oU2bNrz00kt8/fXXLFy4kAEDBnD77bdzyy23MH/+fDp37szpp5/OkCFDOPvss5kxYwZjxoxh/PjxbN26lRkzZnDjjTeyadMmHnvsMXbbbTcmTpzI3nvvzaBBgzj77LNp3bo1V1xxBUBxP+7O/PnzGTJkCCtWrKBRo0Y89NBDHHLIIcVxbNmyhTPPrB3vI1fiEhFJmDlzJkcffXSJdnvuuScHHHAAW7Zs4e2332bGjBk0atSIrl27ctZZZzFixAhmzJjBBx98AMCiRYtK9D9jxgzef/99Nm7cSNu2bfnlL3/J+++/z/XXX8/YsWP58Y9/XFy2S5cuxcMZOnRocTIaPHgwf/jDHzj44IP517/+xdVXX80rr7zCddddxw9/+EMuvfRSRo8encc5U30ocYmIJLh71rvwitqffvrpNG/eHIC+ffvy97//nT59+pQ7zFNOOYUmTZrQpEkT9tprL8455xwADj/8cKZPn561n3HjxvHee+/x8ssvs379ev75z39y/vnnF3f/+uuvAfjHP/7Bs88+C8All1zCzTffvP0TnTJKXCIiCYceemhxIijyxRdfsGTJEurWrVsqqeVyq/luu+1W/LtOnTrFzXXq1GHLli2lys+cOZPbb7+d119/nbp167Jt2zaaNm1aXBPLVNNud6+Ibs4QEUk49dRT+fLLLxk7diwQrjPdeOONDBo0iEaNGjFp0iQ+//xzvvrqK8aPH8/xxx9PkyZNWLduXaWMf+3atfTv35+xY8fSokX4VNaee+5JmzZt+NOf/gSE2t+0adMAOP7443n66acBeOKJJyolhupONS4RqbZyvX29MpkZzz33HFdffTV33XUX27Zto1evXvziF7/gqaee4oQTTuCSSy5h3rx5DBgwgC5dugAhgRx22GH07NmTIUOyvoY1J+PHj2fx4sVceeWVxe0++OADnnjiCX74wx/y85//nM2bN9O/f3+OOOII7r//fgYMGMD9999Pv379dnr608DCu27TTx+SFEm/2bNn06FDh6oOo0xjxoxh6tSpjBo1qqpD2W7Z5q2ZvevuXaoopB2mU4UiIpIqOlUoIpKjQYMGMWjQoKoOo9ZTjUtERFJFiUtERFJFiUtERFJFiUtERFJFN2eISPV1x16VPLy1FRapW7cuhx9+OJs3b6ZevXoMHDiQH//4x9SpU/5x/tChQ5k4cSK9evXi3nvv3e7Qij5HsmjRIv75z38yYMCA7R5GbaHEJSKSsPvuuxe/Wumzzz5jwIABrF27ljvvvLPc/h544AFWrFhR4vVOO2LRokU8+eSTSlzl0KlCEZEy7Lvvvjz44IOMGjUKd2fr1q0MHTqUrl270qlTJx544AEAevfuzYYNG+jWrRt//OMfef755+nWrRtHHnkkp512GsuXLwfgjjvuYOTIkcXDP+yww0q9Sf6WW27hjTfeoHPnztx33327bFrTRDUuEZFyHHTQQWzbto3PPvuMv/zlL+y111688847fP311xx//PGcccYZTJgwgcaNGxfX1FavXs1bb72FmfHwww/zq1/9il//+tc5jW/EiBGMHDmSv/71r/mcrFRT4hIRqUDRq/Fefvllpk+fXvx147Vr1zJ37lzatGlTonxhYSEXXnghn3zyCZs2bSrVXXaOEpeISDkWLFhA3bp12XfffXF3/vu//5sePXqU28+PfvQjbrjhBnr37s2UKVO44447AKhXrx7btm0rLrdx48Z8hl5j6RqXiEgZVqxYwVVXXcU111yDmdGjRw9+//vfs3nzZgDmzJnDhg0bSvW3du1aWrVqBcCjjz5a3L5169a89957ALz33nssXLiwVL+V+YmUmko1LhGpvnK4fb2yffXVV3Tu3Ln4dvhLLrmEG264AYArrriCRYsWcdRRR+HutGjRgvHjx5caxh133MH5559Pq1atOPbYY4sTVL9+/Rg7diydO3ema9eutGvXrlS/nTp1ol69ehxxxBEMGjSI66+/Pr8TnEL6rImIVBvV/bMmaabPmoiIiFQRJS4REUkVJS4RqVZqyuWL6qSmzVMlLhGpNho2bMiqVatq3I62Krk7q1atomHDhlUdSqXRXYUiUm0UFBRQWFjIihUrqjqUGqVhw4YUFBRUdRiVRolLRKqN+vXr6y0TUiGdKhQRkVRR4hIRkVRR4hIRkVRR4hIRkVTJa+IyszPN7CMzm2dmt2TpfqCZTTaz6WY2xcwKEt1+ZWYzzWy2mf2XmVk+YxURkXTIW+Iys7rAaKAn0BG4yMw6ZhQbCYx1907AcOCe2O93geOBTsBhQFege75iFRGR9MhnjesYYJ67L3D3TcDTwLkZZToCk+PvVxPdHWgINAB2A+oDy/MYq4iIpEQ+E1crYEmiuTC2S5oG9Iu/zwOamFlzd3+TkMg+iX8vufvszBGY2WAzm2pmU/XAoohI7ZDPxJXtmlTme1xuArqb2fuEU4FLgS1m1hboABQQkt33zOykUgNzf9Ddu7h7lxYtWlRu9CIiUi3l880ZhcD+ieYCYFmygLsvA/oCmFljoJ+7rzWzwcBb7r4+dnsROBZ4PY/xiohICuSzxvUOcLCZtTGzBkB/YEKygJntY2ZFMdwKPBJ/f0yoidUzs/qE2lipU4UiIlL75C1xufsW4BrgJULSGefuM81suJn1jsVOBj4ysznAfsDdsf0zwHzgQ8J1sGnu/ny+YhURkfSwmvL5gC5duvjUqVOrOgwRkdQws3fdvUtVx7G99OYMERFJFSWuSva3v/2N9u3b07ZtW0aMGFGq++LFizn11FPp1KkTJ598MoWFhQC8+uqrdO7cufivYcOGjB8/HoCLL76Y9u3bc9hhh3H55ZezefPmXTpNIiLVirvXiL+jjz7aq9qWLVv8oIMO8vnz5/vXX3/tnTp18pkzZ5Yo8/3vf9/HjBnj7u6TJ0/2H/zgB6WGs2rVKm/WrJlv2LDB3d1feOEF37Ztm2/bts379+/vv/vd7/I/MSJS4wFTvRrsv7f3TzWuSvT222/Ttm1bDjroIBo0aED//v35y1/+UqLMrFmzOPXUUwE45ZRTSnUHeOaZZ+jZsyeNGjUCoFevXpgZZsYxxxxTXEsTEamNlLgq0dKlS9l//28eXSsoKGDp0qUlyhxxxBE8++yzADz33HOsW7eOVatWlSjz9NNPc9FFF5Ua/ubNm3nsscc488wz8xC9iEg6KHFVIs9yh2bmS+1HjhzJa6+9xpFHHslrr71Gq1atqFfvm+fAP/nkEz788EN69OhRalhXX301J510EieeeGLlBy8ikhL5fHNGrVNQUMCSJd+8nrGwsJCWLVuWKNOyZUv+/Oc/A7B+/XqeffZZ9tprr+Lu48aN47zzzqN+/fol+rvzzjtZsWIFDzzwQB6nQESk+lONqxJ17dqVuXPnsnDhQjZt2sTTTz9N7969S5RZuXIl27ZtA+Cee+7h8ssvL9H9qaeeKnWa8OGHH+all17iqaeeok4dLTIRqd20F6xE9erVY9SoUfTo0YMOHTpwwQUXcOihhzJs2DAmTAhvu5oyZQrt27enXbt2LF++nNtuu624/0WLFrFkyRK6dy/56bGrrrqK5cuXc9xxx9G5c2eGDx++S6dLRKQ60ZszRERqKb05Q0RqlB19mB7g448/5owzzqBDhw507NiRRYsWATB58mSOOuooOnfuzAknnMC8efN21eTkrLZOd6pU9YNklfVXHR5AFqkpdvZh+u7du/vLL7/s7u7r1q0rfpj+4IMP9lmzZrm7++jRo33gwIG7YGpyV9umGz2ALCI1xc48TD9r1iy2bNnC6aefDkDjxo2LH6Y3M7744gsA1q5dW+qu26pWW6c7bXQ7fNT6lheqOoQqsWjEWVUdglRD2R6m/9e//lWiTNHD9Nddd12Jh+nnzJlD06ZN6du3LwsXLuS0005jxIgR1K1bl4cffphevXqx++67s+eee/LWW2/t6kkrV22d7rRRjUtESvGdeJh+y5YtvPHGG4wcOZJ33nmHBQsWMGbMGADuu+8+Jk6cSGFhIZdddhk33HDDrpicnNXW6U4b1bhEpJSdeZi+oKCAI488koMOOgiAPn368NZbb9G7d2+mTZtGt27dALjwwgur3evLaut0p41qXCJSys48TN+1a1dWr17NihUrAHjllVfo2LEjzZo1Y+3atcyZMweASZMm0aFDh104VRWrrdOdNqpxiUgpyYfpt27dyuWXX178MH2XLl3o3bs3U6ZM4dZbb8XMOOmkkxg9ejQAdevWZeTIkZx66qlFd/xy5ZVXUq9ePR566CH69etHnTp1aNasGY888kgVT2lJtXW600YPIEe6OUNEahs9gCwiIrILKHGJiEiq6BqXSA1SW095AyxqOKCqQ6gad6yt6gh2OdW4REQkVZS4REQkVZS4REQkVZS4REQkVZS4REQkVZS4REQkVZS4REQkVZS4JC/y8flzERFQ4pI82Lp1K0OGDOHFF19k1qxZPPXUU8yaNatEmZtuuolLL72U6dOnM2zYMG699dbibpdeeilDhw5l9uzZvP322+y77767ehJEpBpT4pJKl6/Pn4uIgBKX5EG2z58vXbq0RJmiz58DZX7+/Mgjj2To0KFs3bp1l8YvItWbEpdUunx9/lxEBJS4JA+25/Pn77//PnfffTdAqc+f16tXjz59+vDee+/t0vhFpHpT4pJKl4/Pn4uIFFHikkqX/Px5hw4duOCCC4o/fz5hwgQApkyZQvv27WnXrh3Lly/ntttuA0p+/vzwww/H3bnyyiurcnJEpJqxbNcj0qhLly4+derUHe6/tn7HaNGIs6o6BKlEtXU9Bn2Pa0eY2bvu3qUSo8soAqQAAA24SURBVNkl8lrjMrMzzewjM5tnZrdk6X6gmU02s+lmNsXMChLdDjCzl81stpnNMrPW+YxVRETSIW+Jy8zqAqOBnkBH4CIzy7xYMRIY6+6dgOHAPYluY4F73b0DcAzwWb5iFRGR9MhnjesYYJ67L3D3TcDTwLkZZToCk+PvV4u6xwRXz90nAbj7enf/Mo+xiohIStTL47BbAUsSzYVAt4wy04B+wP3AeUATM2sOtAPWmNmfgTbA/wG3uHuJJ1HNbDAwGOCAAw7IxzTUfHfsVdURVJ2duDYgIlUnnzUuy9Iu806Qm4DuZvY+0B1YCmwhJNQTY/euwEHAoFIDc3/Q3bu4e5cWLVpUYugiIlJd5TNxFQL7J5oLgGXJAu6+zN37uvuRwG2x3drY7/vxNOMWYDxwVB5jFRGRlKgwcZnZNWbWbAeG/Q5wsJm1MbMGQH9gQsaw9zGzohhuBR5J9NvMzIqqUd8DSr5eXEREaqVcalzfAt4xs3Hx9vZspwBLiTWla4CXgNnAOHefaWbDzazoNQonAx+Z2RxgP+Du2O9WwmnCyWb2IeG040PbMV0iIlJDVXhzhrv/1Mx+BpwBXAaMMrNxwP+4+/wK+p0ITMxoNyzx+xngmTL6nQR0qnAKRESkVsnpGpeH12t8Gv+2AM2AZ8zsV3mMTUREpJQKa1xmdi0wEFgJPAwMdffN8drUXOA/8xuiiIjIN3J5jmsfoK+7L062dPdtZnZ2fsISERHJLpdThROBz4sazKyJmXUDcPfZ+QpMREQkm1wS1++B9YnmDbGdiIjILpdL4jJPfPvE3beR31dFiYiIlCmXxLXAzK41s/rx7zpgQb4DExERySaXxHUV8F3CewSLXpQ7OJ9BiYiIlCWXB5A/I7yuSUREpMrl8hxXQ+A/gEOBhkXt3f3yPMYlIiKSVS6nCh8jvK+wB/Aa4S3v6/IZlIiISFlySVxt3f1nwAZ3fxQ4Czg8v2GJiIhkl0vi2hz/rzGzw4C9gNZ5i0hERKQcuTyP9WD8HtdPCd/Tagz8LK9RiYiIlKHcxBVfpPuFu68GXgcO2iVRiYiIlKHcU4XxLRnX7KJYREREKpTLNa5JZnaTme1vZnsX/eU9MhERkSxyucZV9LzWkEQ7R6cNRUSkCuTy5ow2uyIQERGRXOTy5oxLs7V397GVH46IiEj5cjlV2DXxuyFwKvAeoMQlIiK7XC6nCn+UbDazvQivgRIREdnlcrmrMNOXwMGVHYiIiEgucrnG9TzhLkIIia4jMC6fQYmIiJQll2tcIxO/twCL3b0wT/GIiIiUK5fE9THwibtvBDCz3c2stbsvymtkIiIiWeRyjetPwLZE89bYTkREZJfLJXHVc/dNRQ3xd4P8hSQiIlK2XBLXCjPrXdRgZucCK/MXkoiISNlyucZ1FfCEmY2KzYVA1rdpiIiI5FsuDyDPB441s8aAufu6/IclIiKSXYWnCs3sF2bW1N3Xu/s6M2tmZj/fFcGJiIhkyuUaV093X1PUEL+G3Ct/IYmIiJQtl8RV18x2K2ows92B3copLyIikje53JzxODDZzP43Nl8GPJq/kERERMqWy80ZvzKz6cBpgAF/Aw7Md2AiIiLZ5Pp2+E8Jb8/oR/ge1+y8RSQiIlKOMhOXmbUzs2FmNhsYBSwh3A5/iruPKqu/jGGcaWYfmdk8M7slS/cDzWyymU03sylmVpDRfU8zW5p4hkxERGq58mpc/ybUrs5x9xPc/b8J7ynMiZnVBUYDPQmfQrnIzDpmFBsJjHX3TsBw4J6M7ncBr+U6ThERqfnKS1z9CKcIXzWzh8zsVMI1rlwdA8xz9wXx/YZPA+dmlOkITI6/X012N7Ojgf2Al7djnCIiUsOVmbjc/Tl3vxA4BJgCXA/sZ2a/N7Mzchh2K8LpxSKFsV3SNEKCBDgPaGJmzc2sDvBrYGh5IzCzwWY21cymrlixIoeQREQk7Sq8OcPdN7j7E+5+NlAAfACUul6VRbbamWc03wR0N7P3ge7AUsLHKq8GJrr7Esrh7g+6exd379KiRYscQhIRkbTL5TmuYu7+OfBA/KtIIbB/orkAWJYxvGVAX4D4LsR+7r7WzI4DTjSzq4HGQAMzW+/uuSRMERGpwbYrcW2nd4CDzawNoSbVHxiQLGBm+wCfu/s24FbgEQB3vzhRZhDQRUlLREQg9+e4tpu7bwGuAV4iPPc1zt1nmtnwxPe9TgY+MrM5hBsx7s5XPCIiUjPks8aFu08EJma0G5b4/QzwTAXDGAOMyUN4IiKSQnmrcYmIiOSDEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKSKEpeIiKRKXhOXmZ1pZh+Z2TwzuyVL9wPNbLKZTTezKWZWENt3NrM3zWxm7HZhPuMUEZH0yFviMrO6wGigJ9ARuMjMOmYUGwmMdfdOwHDgntj+S+BSdz8UOBP4rZk1zVesIiKSHvmscR0DzHP3Be6+CXgaODejTEdgcvz9alF3d5/j7nPj72XAZ0CLPMYqIiIpkc/E1QpYkmgujO2SpgH94u/zgCZm1jxZwMyOARoA8zNHYGaDzWyqmU1dsWJFpQUuIiLVVz4Tl2Vp5xnNNwHdzex9oDuwFNhSPACzbwOPAZe5+7ZSA3N/0N27uHuXFi1UIRMRqQ3q5XHYhcD+ieYCYFmyQDwN2BfAzBoD/dx9bWzeE3gB+Km7v5XHOEVEJEXyWeN6BzjYzNqYWQOgPzAhWcDM9jGzohhuBR6J7RsAzxFu3PhTHmMUEZGUyVvicvctwDXAS8BsYJy7zzSz4WbWOxY7GfjIzOYA+wF3x/YXACcBg8zsg/jXOV+xiohIeuTzVCHuPhGYmNFuWOL3M8AzWfp7HHg8n7GJiEg66c0ZIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKkpcIiKSKnlNXGZ2ppl9ZGbzzOyWLN0PNLPJZjbdzKaYWUGi20Azmxv/BuYzThERSY+8JS4zqwuMBnoCHYGLzKxjRrGRwFh37wQMB+6J/e4N3A50A44BbjezZvmKVURE0iOfNa5jgHnuvsDdNwFPA+dmlOkITI6/X0107wFMcvfP3X01MAk4M4+xiohIStTL47BbAUsSzYWEGlTSNKAfcD9wHtDEzJqX0W+rzBGY2WBgcGxcb2YfVU7otYfBPsDKqo6jStxpVR2BVKJauy7v3Hp8YGWFsSvlM3Flm5ue0XwTMMrMBgGvA0uBLTn2i7s/CDy4c2HWbmY21d27VHUcIjtL63Ltkc/EVQjsn2guAJYlC7j7MqAvgJk1Bvq5+1ozKwROzuh3Sh5jFRGRlMjnNa53gIPNrI2ZNQD6AxOSBcxsHzMriuFW4JH4+yXgDDNrFm/KOCO2ExGRWi5vicvdtwDXEBLObGCcu880s+Fm1jsWOxn4yMzmAPsBd8d+PwfuIiS/d4DhsZ1UPp1qlZpC63ItYe6lLh2JiIhUW3pzhoiIpIoSl4iIpIoSVw1gZreZ2cz46qwPzKxbfIXWR2Y2zcz+YWbtY9kGZvZbM5sfX6f1l4xXbW2Nw5hhZs+bWdPYvrWZfRW7Ff01qKppltpL66gocaWcmR0HnA0cFV+ddRrfPLx9sbsfATwK3Bvb/QJoArRz94OB8cCfzazo2bmv3L2zux8GfA4MSYxufuxW9Lcpv1MnkpXW0VpOiSv9vg2sdPevAdx9ZXw+Lul1oK2ZNQIuA653962x/P8CXwPfyzLsN8nyxhKRakTraC2kxJV+LwP7m9kcM/udmXXPUuYc4EOgLfCxu3+R0X0qcGiyRXxJ8qmUfPbuO4lTMKMrbxJEtp/W0dorn2/OkF3A3deb2dHAicApwB8Tn5B5wsy+AhYBPwL2Jsurswiv2Cpqv7uZfQC0Bt4lvOC4yHx371zpEyGyfbSO1nKqcdUA7r7V3ae4++2Eh777xU4Xx/P8fdx9CTAPONDMmmQM4ihgVvz9VdzwDwQaUPL6gUh1oHW0llPiSjkza29mBydadQYWZyvr7hsIN2r8Jp5mwcwuBRoBr2SUXQtcC9xkZvXzEbvIztA6WnspcaVfY+BRM5tlZtMJ3zi7o5zytwIbgTlmNhc4HzjPs7xCxd3fJ3x6pn+lRy1SCbSO1k565ZOIiKSKalwiIpIqSlwiIpIqSlwiIpIqSlwiIpIqSlwiIpIqSlwiIpIqSlwiIpIq/x8LdToD4QV87QAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "labels = [\"SPORF\", \"RF\"]\n", + "rerf_acc = [rerf_acc_opti, rerf_acc_default]\n", + "rf_acc = [rf_acc_opti, rf_acc_default]\n", + "\n", + "x = np.arange(len(labels))\n", + "width = 0.35\n", + "\n", + "fig, ax = plt.subplots()\n", + "rects1 = ax.bar(x - width / 2, rerf_acc, width, label=\"Optimized\")\n", + "rects2 = ax.bar(x + width / 2, rf_acc, width, label=\"Default\")\n", + "\n", + "# Add some text for labels, title and custom x-axis tick labels, etc.\n", + "ax.set_ylabel(\"Accuracy\")\n", + "ax.set_title(\"Accuracy of Optimized/Default SPORF and RF Models on car Dataset\")\n", + "ax.set_xticks(x)\n", + "ax.set_xticklabels(labels)\n", + "ax.legend()\n", + "\n", + "\n", + "def autolabel(rects):\n", + " \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n", + " for rect in rects:\n", + " height = float(\"%.3f\" % (rect.get_height()))\n", + " ax.annotate(\n", + " \"{}\".format(height),\n", + " xy=(rect.get_x() + rect.get_width() / 2, height),\n", + " xytext=(0, 3), # 3 points vertical offset\n", + " textcoords=\"offset points\",\n", + " ha=\"center\",\n", + " va=\"bottom\",\n", + " )\n", + "\n", + "\n", + "autolabel(rects1)\n", + "autolabel(rects2)\n", + "fig.tight_layout()\n", + "plt.ylim((0.9, 1))\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/model_selection/RF_vs_SPORF_random_search.py b/examples/model_selection/RF_vs_SPORF_random_search.py new file mode 100644 index 0000000000000..34497549c4e47 --- /dev/null +++ b/examples/model_selection/RF_vs_SPORF_random_search.py @@ -0,0 +1,216 @@ +""" +============================================================================ +Demonstration of randomized search to compare classifier performance +============================================================================ +An important step in classifier performance comparison is hyperparameter +optimization. Here, we specify the classifer models we want to tune and a +dictionary of hyperparameter ranges (preferably similar for fairness in +comparision) for each classifier. Then, we find the optimal hyperparameters +through a function that uses RandomizedSearchCV and refit the optimized +models to obtain accuracies. We can see clearly in the plot that the +optimized models perform better than or similar to the default parameter +models. On the dataset we use in this example, car dataset from OpenML-CC18, +SPORF also performs better than RF overall. +""" +print(__doc__) + +from sklearn.model_selection import RandomizedSearchCV + +import pandas as pd +import numpy as np +import math +from rerf.rerfClassifier import rerfClassifier +from sklearn.ensemble import RandomForestClassifier +from sklearn.datasets import fetch_openml +from sklearn.model_selection import train_test_split +from sklearn import metrics +from warnings import simplefilter +simplefilter(action="ignore", category=FutureWarning) +from warnings import simplefilter +simplefilter(action="ignore", category=FutureWarning) + +import matplotlib +import matplotlib.pyplot as plt + + +def hyperparameter_optimization_random(X, y, *argv): + """ + Given a classifier and a dictionary of hyperparameters, find optimal hyperparameters using RandomizedSearchCV. + + Parameters + ---------- + X : numpy.ndarray + Input data, shape (n_samples, n_features) + y : numpy.ndarray + Output data, shape (n_samples, n_outputs) + *argv : list of tuples (classifier, hyperparameters) + List of (classifier, hyperparameters) tuples: + + classifier : sklearn-compliant classifier + For example sklearn.ensemble.RandomForestRegressor, rerf.rerfClassifier, etc + hyperparameters : dictionary of hyperparameter ranges + See https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html. + + Returns + ------- + clf_best_params : dictionary + Dictionary of best hyperparameters + """ + + clf_best_params = {} + + # Iterate over all (classifier, hyperparameters) pairs + for clf, params in argv: + + # Run randomized search + n_iter_search = 10 + random_search = RandomizedSearchCV( + clf, param_distributions=params, n_iter=n_iter_search, cv=10, iid=False + ) + random_search.fit(X, y) + + # Save results + clf_best_params[clf] = random_search.best_params_ + + return clf_best_params + + +############################################################################### +# Building classifiers and specifying parameter ranges to sample from +# ---------------------------------------------------------- +# + +# get some data +X, y = fetch_openml(data_id=40975, return_X_y=True, as_frame=True) #car dataset +y = pd.factorize(y)[0] +X = X.apply(lambda x: pd.factorize(x)[0]) +n_features = np.shape(X)[1] +n_samples = np.shape(X)[0] + +# build a classifier +rerf = rerfClassifier() + +# specify max_depth and min_sample_splits ranges +max_depth_array_rerf = (np.unique(np.round((np.linspace(2, n_samples, 10))))).astype( + int +) +max_depth_range_rerf = np.append(max_depth_array_rerf, None) + +min_sample_splits_range_rerf = ( + np.unique( + np.round((np.arange(1, math.log(n_samples), (math.log(n_samples) - 2) / 10))) + ) +).astype(int) + +# specify parameters and distributions to sample from +rerf_param_dict = { + "n_estimators": np.arange(50, 550, 50), + "max_depth": max_depth_range_rerf, + "min_samples_split": min_sample_splits_range_rerf, + "feature_combinations": [1, 2, 3, 4, 5], + "max_features": ["sqrt", "log2", None, n_features ** 2], +} + +# build another classifier +rf = RandomForestClassifier() + +# specify max_depth and min_sample_splits ranges +max_depth_array_rf = (np.unique(np.round((np.linspace(2, n_samples, 10))))).astype(int) +max_depth_range_rf = np.append(max_depth_array_rf, None) + +min_sample_splits_range_rf = ( + np.unique( + np.round((np.arange(2, math.log(n_samples), (math.log(n_samples) - 2) / 10))) + ) +).astype(int) + +# specify parameters and distributions to sample from +rf_param_dict = { + "n_estimators": np.arange(50, 550, 50), + "max_depth": max_depth_range_rf, + "min_samples_split": min_sample_splits_range_rf, + "max_features": ["sqrt", "log2", None], +} + +############################################################################### +# Obtaining best parameters dictionary and refitting +# ---------------------------------------------------------- +# + +best_params = hyperparameter_optimization_random( + X, y, (rerf, rerf_param_dict), (rf, rf_param_dict) +) +print(best_params) + +# extract values from dict - seperate each classifier's param dict +keys, values = zip(*best_params.items()) + +# train test split +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.33, random_state=42 +) + +# get accuracies of optimized and default models +rerf_opti = rerfClassifier(**values[0]) +rerf_opti.fit(X_train, y_train) +rerf_pred_opti = rerf_opti.predict(X_test) +rerf_acc_opti = metrics.accuracy_score(y_test, rerf_pred_opti) + +rerf_default = rerfClassifier() +rerf_default.fit(X_train, y_train) +rerf_pred_default = rerf_default.predict(X_test) +rerf_acc_default = metrics.accuracy_score(y_test, rerf_pred_default) + +rf_opti = RandomForestClassifier(**values[1]) +rf_opti.fit(X_train, y_train) +rf_pred_opti = rf_opti.predict(X_test) +rf_acc_opti = metrics.accuracy_score(y_test, rf_pred_opti) + +rf_default = RandomForestClassifier() +rf_default.fit(X_train, y_train) +rf_pred_default = rf_default.predict(X_test) +rf_acc_default = metrics.accuracy_score(y_test, rf_pred_default) + +############################################################################### +# Plotting the result +# ------------------- + +labels = ["SPORF", "RF"] +rerf_acc = [rerf_acc_opti, rerf_acc_default] +rf_acc = [rf_acc_opti, rf_acc_default] + +x = np.arange(len(labels)) +width = 0.35 + +fig, ax = plt.subplots() +rects1 = ax.bar(x - width / 2, rerf_acc, width, label="Optimized") +rects2 = ax.bar(x + width / 2, rf_acc, width, label="Default") + +# Add some text for labels, title and custom x-axis tick labels, etc. +ax.set_ylabel("Accuracy") +ax.set_title("Accuracy of Optimized/Default SPORF and RF Models on car Dataset") +ax.set_xticks(x) +ax.set_xticklabels(labels) +ax.legend() + + +def autolabel(rects): + """Attach a text label above each bar in *rects*, displaying its height.""" + for rect in rects: + height = float("%.3f" % (rect.get_height())) + ax.annotate( + "{}".format(height), + xy=(rect.get_x() + rect.get_width() / 2, height), + xytext=(0, 3), # 3 points vertical offset + textcoords="offset points", + ha="center", + va="bottom", + ) + + +autolabel(rects1) +autolabel(rects2) +fig.tight_layout() +plt.ylim((0.9, 1)) + +plt.show()