From c97f35d6e999694d8c5db7a2b80140a961f394c9 Mon Sep 17 00:00:00 2001 From: Neil Zhang Date: Fri, 10 Jul 2020 10:38:32 -0400 Subject: [PATCH] Added problems 9.1 --- ...o Chapter 6 Similarity-Based Methods.ipynb | 33 +--- Solutions to Chapter 9 Learning Aides.ipynb | 182 +++++++++++++++++- libs/data_util.py | 55 ++++++ libs/nn.py | 7 +- 4 files changed, 244 insertions(+), 33 deletions(-) diff --git a/Solutions to Chapter 6 Similarity-Based Methods.ipynb b/Solutions to Chapter 6 Similarity-Based Methods.ipynb index 2684080..0f9ff11 100644 --- a/Solutions to Chapter 6 Similarity-Based Methods.ipynb +++ b/Solutions to Chapter 6 Similarity-Based Methods.ipynb @@ -1377,31 +1377,6 @@ "#### Problem 6.14 (a) Prepare Zip Code Data" ] }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "def split_zip_data(zip_data_path, splits = 1, train_size = 500):\n", - " # Split the raw data into train and test\n", - " # splits: specify the number of random splits for each train-test pair\n", - " X_tr, y_tr, X_te, y_te = data.load_zip_data(zip_data_path)\n", - " train_size = train_size\n", - " splits = splits\n", - " data_splits = data.sample_zip_data(X_tr, y_tr, train_size, splits)\n", - " return data_splits\n", - "\n", - "def set_two_classes(y_train, y_test, digit): \n", - " # Classify digit '1' vs. not '1'\n", - " y_train[y_train==digit] = 1\n", - " y_test[y_test==digit] = 1\n", - " \n", - " y_train[y_train!=digit] = -1\n", - " y_test[y_test!=digit] = -1\n", - " return y_train, y_test" - ] - }, { "cell_type": "code", "execution_count": 3, @@ -1419,7 +1394,7 @@ ], "source": [ "zip_data_path = './data/usps.h5'\n", - "data_splits = split_zip_data(zip_data_path, splits = 1)\n", + "data_splits = data.split_zip_data(zip_data_path, splits = 1)\n", "\n", "X_train, y_train, X_test, y_test = data_splits[0]\n", "\n", @@ -1428,7 +1403,7 @@ "freqs = counts/len(y_train)\n", "print('Frequencies of the digits: \\n', dict(zip(unique, freqs)))\n", "\n", - "y_train, y_test = set_two_classes(y_train, y_test, 1)" + "y_train, y_test = data.set_two_classes(y_train, y_test, 1)" ] }, { @@ -1692,7 +1667,7 @@ "k=3\n", "tot_exps = 1000 \n", "zip_data_path = './data/usps.h5'\n", - "data_splits = split_zip_data(zip_data_path, splits = tot_exps)\n", + "data_splits = data.split_zip_data(zip_data_path, splits = tot_exps)\n", "digit = 1 #we classify digit '1' vs. non '1'\n", "nn_Eins, nn_Eouts = [], []\n", "cnn_Eins, cnn_Eouts = [], []\n", @@ -1700,7 +1675,7 @@ " if (it + 100) % 100 == 0:\n", " print('---- Working on iteration: ', it)\n", " X_train, y_train, X_test, y_test = data_splits[it]\n", - " y_train, y_test = set_two_classes(y_train, y_test, digit)\n", + " y_train, y_test = data.set_two_classes(y_train, y_test, digit)\n", " X_tr, X_te = data.compute_features(X_train, X_test)\n", " \n", " nn_cls = nn.NearestNeighbors(X_tr, y_train, k)\n", diff --git a/Solutions to Chapter 9 Learning Aides.ipynb b/Solutions to Chapter 9 Learning Aides.ipynb index a32e22e..d0c22ac 100644 --- a/Solutions to Chapter 9 Learning Aides.ipynb +++ b/Solutions to Chapter 9 Learning Aides.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -365,7 +365,143 @@ "\n", "#### Exercise 9.18 TODO\n", "\n", - "#### Problem 9.1 TODO" + "#### Problem 9.1" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEICAYAAABLdt/UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHsJJREFUeJzt3X+8VXWd7/HXG0QYBCQJKQ/n6HEq0yhRsOxSJ3/0UPxRWY/RwaburXtmgHvK7EZTHa2mpjxnujNSRvIoJrNmbKK4peMjzPSW4miNBYaKQzomJmQJaSRIMAGf+8dam7PYHs6vvc5e+8f7+Xicxzln73W+67s3m/f+7u9an+9SRGBmZo1jTNEdMDOzfDnYzcwajIPdzKzBONjNzBqMg93MrME42M3MGoyDvQ5J+oSkGwa4/3FJb6xmn0aDpHMk3TTEbd8n6e9Gu0/NStIXJX1siNt+VdKnB7g/JL0kv95ZOQd7laRh+wdJOyU9Jel6SZOK7lclJI2X9FlJT0r6naTlksYNsP0KSQ9L2i/pXUPYRQ8w1LBeAbxD0tED7P9Tkh6UtFfSJwZrMA2gByWNydz2aUlfHWKfqmKwN/p0m8fT190Rmdv+UtKdQ9lHRCyOiE9V2FWrEgd7db0pIiYBpwKnAR8t30CJevl3+QgwF5gFvIzkcT3vMWXcD3QB9w3WsKTTgCMj4t+H0pGI2A18D/jvA2z2KPAhYPVQ2kwdAywYxvYjIumw0d4HcBhweRX2UxVVes7qUr0ESEOJiF+RhNAsAEl3SrpK0j3ALuB4ScdIulnSM5IelfRXZc1MkPRNSTsk3Sfp5P72JWmMpI9I+oWkpyV9S9JR6X3HpaPSd0vanI66F0s6TdIDkrZL+sIAD+VNwOcj4pmI2AZ8HvifAzzuayPiB8DuITxN5wFryh7LNWk/n5W0TtLry/7mTuCCAfb/tYj4HrBjCPsv+T/AJw8VIpJOl/Sj9Lm6X9IZmfveLWlj+m/0mKRFmfvOkLRF0ocl/Qa4Pr39Qknr0/Z+JOlVmb/5sKRfpe09LOlsSfOBK4A/Tz8N3j/AY/l74IOSph7isbxc0u3pa+5hSZdk7jtoekXShyT9Ov209pf9TK+8QNLqtK/3SvrTst2dnz4nv5X096XBTPp6/aikX0raKumfJB2Z3ld6vXZKegL44QCPtak52AsgqRU4H/hZ5uZ3AguBycAvgW8AW0hGjH8G9Eg6O7P9W4BVwFHAvwA3qf9pkPcBFwFvSNv6HXBt2TavAV4K/DnwOeBK4I3AK4BLJL3hUA8l/cr+PrP0H7FCrwQeLrvtp8Bs+h7zKkkTMvdvBPp9g6vAd4BngXeV3yGphWT0/+m0Tx8Evi1perrJVuBCYArwbuCzkk7NNPGi9O+OBRam930FWARMA74E3KxkyusE4L3AaRExGTgXeDwibiWZsvpmREyKiIEe/1qSN78P9vNYjgBuJ3lejwYuBZZLekU/284HPkDyGnkJyWur3KXAJ4EXkHxSuqrs/reSfNo7leS1XBoQvCv9OhM4HpgElA8u3gCcSPIcWD8c7NV1k6TtwN0ko9GezH1fjYiHImIvyX/41wEfjojdEbEe+DJJ+Jesi4j/GxF/BJYCE4DT+9nnIuDKiNgSEXuATwB/VjYC/VS6n9uA54BvRMTW9JPFvwGnHOLxfA+4XNJ0SS8ieRMBmDjE52MgUykbWUfEDRHxdETsjYirgfHACZlNdgB5vKkctFvgY8DHJY0vu+8dwC0RcUtE7I+I20nC8/y0v6sj4heRWAPcBmQ/ZewH/iYi9kTEH4C/Ar4UEfdGxL6I+Bqwh+TfdV/6eE+SNC4iHo+IX4zg8XwcuCzz5lNyIckbxfXp83sf8G2SQUW5S4Dr09frLpIAL/ediPhJ+nr+OskbctZn0k96T5AMJi5Nb/8LYGlEPBYRO4FuYEHZ6/UTEfFc+pxZPxzs1XVRREyNiGMjoqvshbk58/MxwDMRkQ22XwIt/W0fEfvpG92XOxa4Mf1ov51kVLsPmJHZ5qnMz3/o5/dDHeS9iuRTx3rgR8BNwB9JRqqV+h3Jp5cDJC1JpzZ+nz6WI4EXZjaZDPx+JDuT9FA6lbGzfIonIm4BniD5RJV1LHBx6blN+/Q64MVpm+dJ+vd0amM7SeBn+7stPTaQbW9JWXutwDER8SjwfpI35q2SVkrq7997QBGxAfguyfGR8sfymrJ9/wXJIKPcMRz8et3czza/yfy8i+e/hrJ/80v6XrvHpL9n7zuMg1+v/e3PMhzstSO7zOaTwFGSssHWBvwq83tr6Yd0fnJm+nflNgPnpW8opa8J6Wi8sg5H/CEi3hsRLRFxPPA0ySeJfZW2DTxAckAWgDRsP0wyWnxBREwlCfHsVNCJJAdohy0iXpFOZUyKiH/rZ5OPkkxRZT+NbAb+uey5PSIi/i4d3X8b+AdgRtrfW8r6W7606mbgqrL2JkbEN9I+/ktEvI4khAP4zCHaGczfkHw6KB8orCnb96SI+F/9/P2vSV5vJa39bDOY7N+00ffafZLk8WXv28vBgw0vSTsIB3sNiojNJCPgXkkT0gNonSQfaUvmSHpb+hH1/SQf2fs7g+SLwFWSjgVIp03ekkc/JbUoOcgrSaeTTFn8zQDbH57OiQsYlz62Q70Gb+HgudvJJP/BtwGHSfo4ydx11htIpocOtf9x6f7HpG1MkDR2kIcJQETcCTwI/I/MzTcAb5J0rqSxaXtnSJoJHE4ydbIN2CvpPOCcQXbzj8BiSa9Jn9MjJF0gabKkEySdlb5h7Cb5JFV6A30KOG6A57L8sTwKfJO+qTNIRvEvk/TO9Hkap+Qg+on9NPEt4N2STpQ0kWR6Z7j+WtIL0uNNl6f9geTY0v+W1K7kdODS8YO9I9hH03Kw165LgeNIRjA3kszF3p65/19JDnb+jmTu/W3pfHu5a4Cbgdsk7SAJ/9fk1Mc/JXkDeg74GvCRdJ4eAEnfk3RFZvvbSALpv5Gcd/4HoKO/htM53t9LKvX1+ySh/QjJx/PdZD6Sp4F9ftqPQ/nHdJ+Xkoy+/8DBxy0G81GSg52lPm4mOfB3BUmAbwb+GhiTTqO9jyQEfwe8neTf4ZAiYi3JSPoL6d88St9B2/Ek5/T/lmSa4+h0v5AcRAd4WtKgp5Km/hY4cE572t9zSE7tfDLdx2fS/Zb383skZ0Ddkfbxx+lde4a4b0hev+tIpvFWA9elt38F+GfgLmATyb/zZcNo1wD5QhtWqySdA3RFxEVD2PYyoDUiPjT6PbOsdFS/ARjvkXVtcLCb2bBJeivJSPsIkk9J+4fyBmzV4akYMxuJRSTTT78gmevv7yCrFcQjdjOzBuMRu5lZgylkEZ0XTpoUx02bVsSuR8fOnfzXMzvZNvZFHN5SXtBnZpaPJ55Y99uIGDRkCgn246ZNY+2VVxax69Fz1108sfIe+NXT3Lp8U9G9MbMGtGiRfjn4Vp6KyU9HB23Luxk/AaZ/4B1F98bMmpiDPWczlnYzZ/c9zO9q5667iu6NmTUjB/soaFvezdSp0LOyndMuy6vI08xsaGrmCiR/POwwthx/PLsn5rHia3Em7NrFzMceY0pPd7KQSVcvvb3Q3V10z8ysWdRMsG85/ngmt7Zy3OTJSBr8D2pQRPD0jh1sAdofeQSAtlZYs7kduvBBVTOripqZitk9cSLT6jjUASQxbfLkgz91dHfTtjwZrs/vaqe3t6DOmVnTqJlgB+o61EsO9RjalnfTNudo1mz2QVUzG101FewNr7OT8ROSg6qPX7Gi6N6YWYNysA/Rqptv5hWvfz1jZsxg7fr1I25nxtJu2hbMY/H2XuZ3tefYQzOzhIO9H3fecw/vuuzgtf1nvfzlfOf66+l47Wsr30GmmMnhbmZ5q99g37cPbrsNrr46+b4vj8tsHtqJL3sZJ7zkJbm2OWNp30FVz7ubWV5q5nTHYdm3Dy65BO67D3btgokT4dRT4VvfgrFDuoRlzWhb3s2zV/TSs7Kdp1cdzU+X3Vt0l8ysztVnsP/gB0moP/dc8vtzzyW//+AHcM5g1ws+tNfMn8+ePXvY+dxzPLN9O7PPPBOAz3zsY5x71ll59Lxf2WKm+V3tPt/dzCpSn1MxDz6YjNSzdu2CDRsqavbeW29l/R138OXPfpY3n3su6++4g/V33DGqoZ6VPd/dzGyk6jPYX/nKZPola+JEmDWrmP7kyMVMZlap+gz2s89O5tSPOAKk5Puppya3j5IbV69m5skn8+O1a7ng7W/n3EsuGbV9ZYuZvASwmQ1XIdc8nXvssVF+oY2Ns2dzYvswpiD27Uvm1DdsSEbqZ59dMwdON27axIkVnOt+QHrxji9O7ea4noWVt2dmdW3RIq2LiLmDbVefB08hCfFzzqnoYGnN6+igDVi8she6en1Q1cyGpD6nYppJWTGTlyIws8E42OvEjKXdtLXC4u29LmYyswE52OtJd9+VmXxQ1cwOpX7n2JtUtpiJrnYunrOJzs6ie2VmtcQj9jpVuq7qF9f7mqpmdjAHez9+/p//yWvPO4/xM2fyD9deW3R3DmlKTzfT9m11MZOZHcTB3o+jpk7l8z09fLCrq+iuDKptefeB66p6KQIzAwd7v46ePp3TTjmFcYfVySGI9LqqY8fi0yHNLJ9gl/S4pAclrZe0No82h+KMi47kjIuOrNbual7L7KN9ZSYzy3XEfmZEzB5KuWsezrjoSNb86HDW/OhwB3xJZ6eLmcwGcPXVyVej81RM6trrrmP2mWcy+8wzefI3vym6OxXJXlfVxUxmiauvhkceSb4aPeDzmkQO4DZJAXwpIp43VJS0EFgI0HbUURXv8M6bfn9glH7nTb+vuL33dHbynkY6Ibyjg6m33kPPynbW3TSPbUtvKLpHZlYleQX7vIh4UtLRwO2Sfh4RB40V07BfAcnqjjntd1T85qmnmHvOOTy7YwdjxozhcytW8B93382UyZOL7tqwlBczXbFgEx0dRffKrBhLlvSN0pcsKbYvoy2XYI+IJ9PvWyXdCLwaGPVJgDxG6v150YwZbLn//lFpuwjZ66re2uEVIs0aXcVz7JKOkDS59DNwDlDZNeosd1N6+q7MdN11BXfGrCBLljT+aB3yOXg6A7hb0v3AT4DVEXFrDu1azkrFTKvWuZjJrJFVHOwR8VhEnJx+vSIirqqgrUq7U7iafwxpMRPAaZd5nRmzRlQzpztO2LWLp3fsqP1gHEBE8PSOHUzYtavorgyqbcG8A+vMmFljqZma+ZmPPcYWYNvEiUV3pSITdu1i5mOPFd2NwXV00NbRwa8uSypVfV1Vs8ZRM8E+bu9e2h95pOhuNJ2WZd1w110sXtnLF6/A4W7WAGpmKsYK1NFB25xknRlfmcms/jnYLZGuMzNn9z3M72r3UgRmdczBbgcpXZmpZ6Uv3mFWrxzs9jxTepJwX7PZxUxm9cjBbv2a0uNiJrN65WC3Q8sUM3kpArP64WC3QZUu3rFqnUfuZvXAwW5DMmNpck1VX5nJrPY52G3IWpb1XZnJ8+5mtcvBbsPT0XFg3t3FTGa1ycFuI+JiJrPa5WC3EcsWM3kJYLPa4WC3ikzpSU6JnLZvq0+HNKsRDnbLhYuZzGqHg93yUVbM5HVmzIrjYLdctS3vpm3O0azZ7JG7WVEc7Ja/zk4XM5kVyMFuo8LFTGbFcbDb6EmLmcZPcDGTWTXlFuySxkr6maTv5tWmNYYZS13MZFZNeY7YLwc25tieNZBsMZNH72ajK5dglzQTuAD4ch7tWWMqFTPN2X2PT4c0G0V5jdg/B3wI2H+oDSQtlLRW0tptO3fmtFurR22tyWX3vAyB2eioONglXQhsjYh1A20XESsiYm5EzJ0+aVKlu7V61t23DIGLmczyl8eIfR7wZkmPAyuBsyTdkEO71uDalncfGL17nRmz/FQc7BHRHREzI+I4YAHww4jw0TEbmu6+y+65mMksHz6P3Qo3Y2myDIGLmczykWuwR8SdEXFhnm1ak+jsPFDM5PPdzSrjEbvVlNJFs3tWOtzNRsrBbjWnZZmLmcwq4WC3mpQtZvK8u9nwONitppUu3uFiJrOhc7BbzcsWM/l8d7PBOditLpSKmVat81IEZoNxsFv96E4u3jFt31YvQ2A2AAe71ZeOjgPXVPVBVbP+Odit/qTFTL6uqln/HOxWt1qW9S1F4GImsz4OdqtvnZ2Mn+BiJrMsB7vVvRlLDy5m8ujdmp2D3RpG6bqqS1b5dEhrbg52ayhTelzMZOZgt4aTLWbyKZHWjBzs1pjS66oCLmaypuNgt4bmYiZrRg52a2wuZrIm5GC3ptCyLFlnxsVM1gwc7NY8OjpczGRNwcFuTcXFTNYMHOzWlErFTD0rfVDVGk/FwS5pgqSfSLpf0kOSPplHx8xG25Se5HRIFzNZo8ljxL4HOCsiTgZmA/MlnZ5Du2ajzsVM1ogqDvZI7Ex/HZd+RaXtmlVNppjJl92zRpDLHLuksZLWA1uB2yPi3n62WShpraS123bufH4jZgXLXjTbrJ7lEuwRsS8iZgMzgVdLmtXPNisiYm5EzJ0+aVIeuzXLnYuZrBHkelZMRGwH7gTm59muWTVli5l8vrvVozzOipkuaWr6858AbwR+Xmm7ZoVKL5o9Z/c9DnerO3mM2F8M3CHpAeCnJHPs382hXbNipevMuJjJ6k0eZ8U8EBGnRMSrImJWRPxtHh0zqxXZYibPu1s9cOWp2RBM6UnOd1+8vdfFTFbzHOxmQ9XdV8zk892tljnYzYYjLWbydVWtljnYzUagbXk34ycko3ezWuNgNxuhGUtdzGS1ycFuVoGWZd20zTmaxdt7vRSB1QwHu1mlMtdVdTGT1QIHu1lOWi6e52ImqwkOdrO8dHQcVMzk0bsVxcFulrMpPX3XVfXpkFYEB7vZKHExkxXFwW42WlzMZAVxsJuNsux1Vc2qwcFuVg3dfcVMvb1Fd8YanYPdrEpKxUxrNre7mMlGlYPdrJpczGRV4GA3K4CLmWw0OdjNipAWM42f4GImy5+D3axAM5b2FTN5hUjLi4PdrAaULrvnYibLg4PdrBa4mMly5GA3qyHZYiaHu41UxcEuqVXSHZI2SnpI0uV5dMysaXUnK0SuWudiJhuZPEbse4ElEXEicDrwHkkn5dCuWdOa0uNiJhu5ioM9In4dEfelP+8ANgItlbZr1vQyxUw+392GI9c5dknHAacA9/Zz30JJayWt3bZzZ567NWtoLcuScO9Z6ZG7DU1uwS5pEvBt4P0R8Wz5/RGxIiLmRsTc6ZMm5bVbs6bQsiwpZprf5WImG9xheTQiaRxJqH89Ir6TR5tmdrAZS7uTH7p6oaudW5dvKrZDVrPyOCtGwHXAxohYWnmXzGwgbcuTgHcxkx1KHlMx84B3AmdJWp9+nZ9Du2Z2CC5msoHkcVbM3RGhiHhVRMxOv27Jo3NmdmjZYiaP3i3Lladm9SyzFIGLmazEwW7WAFzMZFkOdrNGUFbM5CWAm5uD3ayBlK6runi752WamYPdrNF0drqYqck52M0a0Iyl3bQt8HVVm5WD3axRpddVnToVlqzy6ZDNxMFu1uCm9LiYqdk42M2agIuZmouD3axZuJipaTjYzZqMi5kan4PdrNm4mKnhOdjNmlS2mMmnQzYWB7tZM0uLmXpWupipkTjYzZqci5kaj4PdzA4UM5VG71bfHOxmdkDpuqouZqpvDnYzO4iLmeqfg93Mni9TzOTz3euPg93MDqlted/UjNUPB7uZDShbzOSlCOqDg93MBlUqZlqz2fPu9SCXYJf0FUlbJW3Io72atX8/PPAArF6dfN+/v+gemVVPZydtC+Yxbd9WL0NQ4w7LqZ2vAl8A/imn9mrP/v1wzTWwaRPs2QPjx0N7O1x+OYzxBx9rEh0dtAGLV/ZCVy9XLNhER0fRnbJyuSRSRNwFPJNHWzVrw4a+UIfk+6ZNye1mzaSsmMmj99pTtaGmpIWS1kpau23nzmrtNj+bN/eFesmePcntZk1oxtLkfHcvIlZ7qhbsEbEiIuZGxNzpkyZVa7f5aW1Npl+yxo9PbjdrVt3JNVV7Vvqgai3x5PBQzZqVzKmXwr00xz5rVrH9MivYlJ6Di5m8FEHx8jp42vjGjEkOlG7YkEy/tLYmoe4Dp2ZAcr77s1f0smpdO7d2biq6O00tr9MdvwH8GDhB0hZJnXm0W3PGjIFXvQouuCD57lA3O8iUnr5KVRczFSevs2IujYgXR8S4iJgZEf4wZtak2pZ3+7qqBfOQ08zyl7muqk+HrD4Hu5mNmpaL57F4e69H7lXmYDez0ZMpZprf5euqVouD3cxG3YylySmRc3bf42KmKnCwm1nVuJipOhzsZlY1LmaqDge7mVVd9rqqlj8Hu5kVo9vFTKPFwW5mhXEx0+hwsJtZsTLFTD6omg8Hu5nVhJaL5x04qGqVcbCbWW1wMVNuHOxmVlOyxUzzu9pd0DQCDnYzq0lty/uuzuTR+/D4QhtmVrOm9HQzBaCrF7rauXjOJjob82oPufKI3cxqXmn0vmqdlyMYCge7mdUFL0cwdA52M6sr2eUIfGpk/xzsZlZ/upPRO3hJgv442M2sbnlJgv452M2svmWWJPDoPeFgN7OG0LLMo/eSXIJd0nxJD0t6VNJH8mjTzGzYykbvj1+xougeFaLiYJc0FrgWOA84CbhU0kmVtluT9u+HBx6A1auT7/v3F90jM+tHy7Ju2hbMY/H23qYcvedRefpq4NGIeAxA0krgLcB/5NB27di/H665BjZtgj17YPx4aG+Hyy+HMZ7RMqs5HR20dXTw1AeScF83YR7blt5QdK+qIo9EagE2Z37fkt7WWDZs6At1SL5v2pTcbmY1qxkXFcsj2NXPbfG8jaSFktZKWrtt584cdltlmzf3hXrJnj3J7WZW85ppUbE8gn0L0Jr5fSbwZPlGEbEiIuZGxNzpkyblsNsqa21Npl+yxo9PbjezulBalqA0em/UZQnyCPafAi+V1C7pcGABcHMO7daWWbOSOfVSuJfm2GfNKrZfZjZsjb6oWMUHTyNir6T3At8HxgJfiYiHKu5ZrRkzJjlQumFDMv3S2pqEug+cmtWl7JLA8xtsSWBFPG86fNTNPfbYWHvllVXfr5lZv3p7eSI9XHbr8k3F9mUAixZpXUTMHWw7DzfNzBpsUTEHu5lZqlEWFXOwm5llNcCyBA52M7N+lBYVq8dlCRzsZmaHko7ex0+or9G7g93MbBAzltbXomIOdjOzoejoOGj0XsvLEjjYzcyGoR4WFXOwm5mNQHZRsVpblsDBbmY2QqVFxabt21pTi4o52M3MKtS2vJu21tpZVMzBbmaWh+7aGb072M3McpQdvRd1aqSD3cwsbwUvKuZgNzMbJUUtKuZgNzMbTQUsKuZgNzOrgpZl1VuWwMFuZlYtZcsSjNbo3cFuZlZl5YuK5b0sgYPdzKwImdF7z8p8FxVzsJuZFWg0FhVzsJuZ1YA8FxVzsJuZ1Yi8FhWrKNglXSzpIUn7Jc2tpC0zM0tUuqjYYRXufwPwNuBLFbZjZmZZ3d20AXQlZ868oXXTkP+0ohF7RGyMiIcracPMzA6tNHpfs3noRU2VjtiHTNJCYGH66x4tWrShWvseBS8Eflt0Jyrg/hfL/S9WPff/2KFsNGiwS/p/wIv6uevKiPjXofYmIlYAK9I210ZE3c7Ju//Fcv+L5f7XvkGDPSLeWI2OmJlZPny6o5lZg6n0dMe3StoCvBZYLen7Q/zT0V+3cnS5/8Vy/4vl/tc4RUTRfTAzsxx5KsbMrME42M3MGkxhwV6PyxFImi/pYUmPSvpI0f0ZLklfkbRVUt3VEEhqlXSHpI3p6+byovs0HJImSPqJpPvT/n+y6D6NhKSxkn4m6btF92W4JD0u6UFJ6yWtLbo/o6nIEXtpOYKcl5gfHZLGAtcC5wEnAZdKOqnYXg3bV4H5RXdihPYCSyLiROB04D119vzvAc6KiJOB2cB8SacX3KeRuBzYWHQnKnBmRMxu9PPYCwv2OlyO4NXAoxHxWET8F7ASeEvBfRqWiLgLeKbofoxERPw6Iu5Lf95BEi4txfZq6CKxM/11XPpVV2cuSJoJXAB8uei+2MA8xz50LcDmzO9bqKNgaSSSjgNOAe4ttifDk05jrAe2ArdHRF31H/gc8CFgf9EdGaEAbpO0Ll3ipGGN6loxeS1HUCPUz211NeJqBJImAd8G3h8Rzxbdn+GIiH3AbElTgRslzYqIujjeIelCYGtErJN0RtH9GaF5EfGkpKOB2yX9PP0U23BGNdgbbDmCLUBr5veZwJMF9aUpSRpHEupfj4jvFN2fkYqI7ZLuJDneURfBDswD3izpfGACMEXSDRGR34U6R1lEPJl+3yrpRpLp1YYMdk/FDN1PgZdKapd0OLAAuLngPjUNSQKuAzZGxNKi+zNckqanI3Uk/QnwRuDnxfZq6CKiOyJmRsRxJK/9H9ZTqEs6QtLk0s/AOdTPm+qwFXm640iXIyhEROwF3gt8n+TA3bci4qFiezU8kr4B/Bg4QdIWSZ1F92kY5gHvBM5KT1dbn44e68WLgTskPUAySLg9IurulME6NgO4W9L9wE+A1RFxa8F9GjVeUsDMrMF4KsbMrME42M3MGoyD3cyswTjYzcwajIPdzKzBONjNzBqMg93MrMH8f5HIq2FC5SolAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "#### Problem 9.1\n", + "\n", + "df = pd.DataFrame({'x1':[0, 0, 5], 'x2':[0, 1, 5], 'y':[1, 1, -1]})\n", + "xsp1 = df.loc[df['y']==1]['x1'].values\n", + "ysp1 = df.loc[df['y']==1]['x2'].values\n", + "xsm1 = df.loc[df['y']==-1]['x1'].values\n", + "ysm1 = df.loc[df['y']==-1]['x2'].values\n", + "\n", + "#plt.tight_layout()\n", + "X_train = df[['x1', 'x2']].values\n", + "y_train = df['y'].values\n", + "cls = nn.NearestNeighbors(X_train, y_train, 1)\n", + "x1_min, x1_max = -1, 6\n", + "x2_min, x2_max = -1, 6\n", + "xx1, xx2 = myplot.get_grid(x1_min, x1_max, x2_min, x2_max, step=0.02)\n", + "myplot.plot_decision_boundaries(xx1, xx2, 2, cls)\n", + "\n", + "myplot.plt_plot([xsp1, xsm1], [ysp1, ysm1], 'scatter', \n", + " colors = ['r', 'b'], markers = ['o', '+'], labels = ['+1', '-1'], \n", + " title = \"Problem 9.1 (a) 1-Nearest Neighbor\", yscale = None, ylb = -1, yub = 6,\n", + " xlb = -1, xub = 6, xlabel = None, ylabel = None,\n", + " legends = ['+1', '-1'], legendx = None, legendy = None, marker_sizes=[25, 25])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transformed data points: [[ 0.18441744 -1.40213773]\n", + " [-1.30649561 0.54135868]\n", + " [ 1.12207817 0.86077905]]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEICAYAAABLdt/UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHxJJREFUeJzt3X2YVXW99/H3B0QIATkaYPIgWFbW5CNqSY0PeUjT0rrLo1nn6twUGmV2aacjWqfyHJm6uvVOTa8Op0nrWFIe0zxppreJI1QmGBjkAwYmhAk+oCBKCd/7j98a2LOZGWZmr5m1Z83ndV1zzd5rr73Wdz999m/9fmuvpYjAzMzKY1DRBZiZWb4c7GZmJeNgNzMrGQe7mVnJONjNzErGwW5mVjIO9nZI+oqk6zu5/QlJJ/RlTb1B0nRJt1RcD0lv6GDe90ual/P6O1vfWZLuzHN9Fct+l6RHe2PZ1vckXSTpO12cd0B8tksT7NkL8rKkTZKelnStpBFF11ULSUMl/V9JayU9L+kaSUM6mX+upEclbZP08S6sYg7wta7UEhG3Ag2SDupg3bMl3V41bUUH087owvp+EBHTK+7X4ZdAd0XEfRHxpjyWlRdJn5G0SNIWSdd1Yf4nsvf5HhXTPiFpfm/W2V2SPi5pwS7mmS/pFUkTK6adIOmJrqwjIuZExCdqLLVUShPsmfdFxAjgMOAI4IvVMyjpL4/7QmAq0AC8kfS4dnpMFZYCs4AHd7VgSUcAe0bEb7pRzw3AzA5uawGmSRqcLX8fYAhwWNW0N2Tzll4Watd1cfa1wL8D3+3GKnYDzutuXd0labfeXgfwEvClPlhPn+ij56xD/SXguiUi/gz8nBSIrS2CSyUtBDYD+0vaV9Ktkp6T9LikT1YtZpikH0naKOlBSQe3ty5JgyRdKOmPkp6V9GNJe2W3Tc5amv8kaXXW6j5H0hGSHpK0QdK3Onko7wOujIjnImI9cCXwvzt53FdHxN3AK114mk4C7m1n+nslrZT0jKRvVH0JzgdO7mB5D5CC/JDseiNwD/Bo1bQ/RsTaivudkLXin5d0tSRB25aepNYvgqXZFtk/ZNNPkbQkex5/Vbk1kbVoP589zy9kr+Ww7LZjJa3pyrzZ7V+Q9FS25fSJPLceWkXETyLiFuDZbtztG8DnJY1u70ZJb5Z0V/Yef1TS6RW3nSzpd5JezN6bX6m4rfV9O0PSk8Avs+lvz57nDZKWSjq24j4fz943GyWtUupKOxD4NvCO7HXb0MljuRI4s6PnNfu83iRpfbb8z1bc1qZ7RdI/SvpT9nn8knbuXtld0vezWpdLmlq1uiMk/SF7T15b9V74ZJYXz2X5sW/FbSHp05JWACs6eay9rpTBrrRJ917gdxWTP0ZqbY4E/kRqfa4B9gU+BMyR9O6K+U8FbgT2An4I3KL2u0E+C5wGHJMt63ng6qp5jgIOAP4B+CZwMXAC8FbgdEnHdPRQsr/K6xMk7dnRY++Gt5FCt9oHSFsJh5Geg8ovkoeByZJGVd8pIv4K3E8Kb7L/9wELqqZVt9ZPIW1dHQycDrynnWW33v/giBgRET+SdBipdXs2sDfwH8CtkoZW3PV04ERgCnAQ8PF2Hm+n80o6ETif9Hq9gfQ614tFpC/bz1ffoNRFcxfpvTsWOBO4RtJbs1leAv4RGE36sv6UpNOqFnMMcCDwHknjgdtIWxV7Zeu8SdKYbF1XAidFxEjgaGBJRDwMnAP8Onvd2v0CyvwZ+E/gK+08lkHA/5C2SMcD7wY+J2mn94qktwDXAGcBrwP2zO5T6f3AvOyx3wpUN67OIr0PX0/aUv5ituzjgSbSe+V1pBypHnc6jfR5f0snj7XXlS3Yb8laBQtIrdE5FbddFxHLI+JVYB/gncC/RMQrEbEE+A4p/Fstjoj/joi/AZcDw4C3t7POs4GLI2JNRGwhvTE/pLabYv+WredO0gfqhohYl21Z3Acc2sHj+TlwXvbh2Yf0JQIwvIvPR2dGAxvbmf71bAvhSdKX0JkVt7XO39EH9F52hPi7SI/tvqpp1VsJX4uIDdn67mFH635XPgn8R0TcHxFbI+J7wBbavkZXRsTaiHiOFAydLbujeU8Hrs3eO5uBr3axvr7yr8C5ksZUTT8FeCIiro2IVyPiQeAmUiOGiJgfEb+PiG0R8RCpoVP9pfWViHgpIl4GPgrcHhG3Z/e5i/TF8t5s3m2kMZjXRMRTEbG8B4+lCXhfxZdPqyOAMRFxSUT8NSJWkr4E2hur+RDwPxGxIGts/CtQfUCsBdnj2Ar8F6lRUelbEbE6ey9cyo7PwFnAdyPiweyzPpu0NTK58jFkn5+Xu/PA81a2YD8tIkZHxH4RMavqyV1dcXlf4LmIqAy2P9H2m337/BGxjR2t+2r7ATdnm6cbSK3arcC4inmerrj8cjvXOxrkvZS01bEE+BVwC/A3YF0H83fH86Stl2qVz9OfaPuYW+fvaJO6BXinpL8jfRBXkOo+OpvWwM4t9r9UXN5Mx89Ftf2AC1qf9+y5n1hVb3eW3dG8+9L2Oam8vBOlAe7Weq4BPlJR40OdP6QOl/nzrCtjk6SzKm+LiGXAz0jjMZX2A46qen7OIjVqkHSUpHuyro0XSC3r11YtY3XV8j5ctbx3Aq+LiJdIW6PnAE9Juk3Sm7v7OLPuxm8Bl7TzWPatWvdFtP2MtWrzemVfxtXdW9Wv9bCqhlhHn4F9s+uty96ULbvd3ChS2YK9M5Xf2muBvSRVBtsk0uZgq8oR+kHAhOx+1VaTNkFHV/wNy1rjtRUc8XJEfCYixkfE/qQ30eKspVGrh0ibmdUmVlyeRNvHfCCpFfhiB8v8NWnTdyawECCbd202bW1ErKqx7largUurnvfhEXFDTstv9RTptW81saMZAbIGxeis22EW8MOK+trdo2hXIuKkrCtjRET8oJ1ZvkzagqkOmHurnp8REfGp7PYfkrohJkbEnqS+cNFWVC3vv6qWt0dEfC2r8RcR8fekLopHSC3q6mV0xTeA44DDq9a9qmrdIyPive3cv83rJek1pK667ujoM7CW9CXTuuw9smVXftbr4nC5AynYt4uI1aSWZJOkYUqDbjOAyg/N4ZI+mH2Tf460md/eHiTfBi6VtB9A1m1yah51ShqfDRpJ0ttJew18uZP5d88GegQMyR5bR6/x7bTfX/zPkv4uG6c4D/hRxW3HkLqH2pVtIS0i9UnfV3HTgmxaLXvDPA3sX3H9P4FzspanJO2hNCDY3lZILX4M/JOkAyUNJ23a507SbtlrNxgYnL12XdqzIiIeJ71On62Y/DPgjZI+JmlI9neE0oAmpK2v5yLiFUlHAh/ZxWquJ3WTvEdSa33HSpogaZzS7xz2IH1ONpG2WiG9bhMk7d7Fx7IBuAz4QsXk3wIvSvoXSa/J1t+gtGdXtf/O6jw6W+dX2fkLa1c+nT2uvUhbBq2fgR+S3guHZGM5c4D7I+KJbi6/1w3IYM+cCUwmfQvfDHw56zds9VPS5uXzpL73D2b97dWuILV87pS0kRT+R+VU4+tJX0AvAd8DLsz66YHtm+gXVcx/J6lr52hgbna5kXZkfa4vSKqu9afAYlL3z21Ac8VtZ5IGKTtzL2mwrnLf5fuyabUE+1eA72Wb4qdHxCJSK/VbpNfocTofHO2RiPg5aWDwnmwdv85u2pLzqr5Ier0uJPVnv0znu7ZWuwTYvk971s04ndQPvZbU/fB1oHVweRZwSfae/VfSF1iHssbQqaSgW09qRf8zKUMGARdk63mO1ACYld31l8By4C+SnuniY7mCHV8MZFuo7yONe6wCniGNie20E0HWt38uaVDzKdK40Dq693r9kPRZWpn9/Xu27LtJjaubsmW/nvb7+Qun8Ik2BixJ04FZEVG9N0R7874P+FhEnL6recssa/EuA4ZGGoi3Oqb0I8UNwAE5dgPWPQe72S5I+gBp62UP0pbTtq58GVoxskbI3aQumMtIW9CHxQAKu4HcFWPWVWeTuh/+SOoi+FTns1vBTiV1C60l/X7kjIEU6uAWu5lZ6bjFbmZWMoUcqOa1I0bE5L27u2upWWbTJjY//wovTTxw1/OalciTTy5+JiKqf2W8k0KCffLee7Po4ouLWLWVQUsLv7txBQ9cfH/RlZj1qbPP1p92PZe7Yqyf2nvrOsac/9GiyzCrS4UeM9isRxobmQQ8e+MK1hddi1kdcovdzKxk6qbF/rfddmPN/vvzyvA8jkhbnGGbNzNh5UqGvOofJZpZMeom2Nfsvz8jJ05k8siRSN09Zk99iAie3biRNcCUxx4ruhwzG6DqpivmleHD2bsfhzqAJPYeObLfb3X0Fx5ANWtf3QQ70K9DvVUZHkO/0NjIpMPHFl2FWV2qq2A3M7PaOdi76MZbb+Wt73oXg8aNY9GSJUWXY2bWIQd7O+YvXMjHzz23zbSGN7+Zn1x7LY3veEdBVVl7Dn9lIS21nL7DrIT6b7Bv3Qp33gmXXZb+b83jNKAdO/CNb+RNb3hDr67DumnGDEaPhv91iwdQzSrVze6O3bJ1K5x+Ojz4IGzeDMOHw2GHwY9/DIMHF12d9aFRrx+bTrxmZtv1z2C/++4U6i+9lK6/9FK6fvfdMH16jxd71IknsmXLFja99BLPbdjAIccdB8DXv/Ql3nP88XlUbmbW6/pnsP/+96mlXmnzZli2rKZgv/+OO4DUx37dvHlcd9VVtVRpZlaI/tnH/ra3pe6XSsOHQ0NDMfVYoTyAatZW/wz2d7879anvsQdI6f9hh6XpveTm225jwsEH8+tFizj5Ix/hPaef3mvrsm7IBlAn3TG36ErM6kb/7IoZPDgNlN59d+p+aWhIoZ7TwOmx06Zx7LRpbaZ94OST+cDJJ+eyfMvXqNeP5YjlLaxnZtGlmNWF/hnskEJ8+vSa+tTNzMqof3bFmJlZhxzs1v8dcACHv7KQJy5yP7sZONitDBobmTSx6CLM6oeD3cysZBzsZmYl42BvxyMrVvCOk05i6IQJ/J+rry66HDOzbnGwt2Ov0aO5cs4cPj9rVtGlWFdNm8Y5G5o8gGqGg71dY8eM4YhDD2XIbv13N/8BxwOoZtvlklySngA2AluBVyNiah7L3ZVjT9sTgPm3vNAXqzMz6xfybJIeFxHP5Li8Th172p7c+6vdt18GB7yZde6yy9L/Cy4oto7e5q6YzNXNzRxy3HEcctxxrP3LX4oux8xydtll8Nhj6e+yy3aEfBnlFewB3ClpsaR2j8QkaaakRZIWrd+0qeYVzr/lBY45+q8cc/RfmX/LCzW31j89YwZL7rmHJffcw7777FNzfVaAbAC1qanoQsyKlVdXzLSIWCtpLHCXpEcios0RsiNiLjAXYOp++0VO6+0Vf3n6aaZOn86LGzcyaNAgvjl3Ln9YsIBRI0cWXZp1prGRSQsXFl2F1akLLhg4XTG5BHtErM3+r5N0M3Ak0OunPuitPvV9xo1jzdKlvbJsM7PeVnOwS9oDGBQRG7PL04FLaq7MzCxnZW+pt8qjj30csEDSUuC3wG0RcUcOyzXrkXtXT3E/uw1oNQd7RKyMiIOzv7dGxKU1LKvWcgpXhsfQr82ezejRRRdhVqy62d1x2ObNPLtxY78Oxojg2Y0bGbZ5c9GlmNkAVje/mZ+wciVrgPXDhxddSk2Gbd7MhJUriy7DzAawugn2Ia++ypTHHiu6DDOzfq9uumLM8uQBVBvIHOxWOqPmeADVBjYHu5lZyTjYzcxKxsFupXXv6ilFl2BWCAe7ldKoObMZPBiam4uuxKzvOdittHwwThuoHOxmZiXjYDczKxkHu5XajYs9gGoDj4PdSmvUnNmAB1Bt4HGwW6kNHVZ0BWZ9z8FuZlYyDnYzs5JxsFvpeQDVBhoHu5XauMs9gGoDj4PdSs8DqDbQONjNzErGwW5mVjK5BbukwZJ+J+lneS3TLA/jxqQB1JaWoisx6xt5ttjPAx7OcXlm+ZidDuFrNlDkEuySJgAnA9/JY3lmZtZzebXYvwl8AdjW0QySZkpaJGnR+k2bclqtmZlVqznYJZ0CrIuIxZ3NFxFzI2JqREwdM2JEras165bdhsCKFUVXYdY38mixTwPeL+kJYB5wvKTrc1iuWW48gGoDSc3BHhGzI2JCREwGzgB+GREfrbkyszx5ANUGEO/HbmZWMrvlubCImA/Mz3OZZmbWPW6x24AyZ5772a38HOw2YIy/yv3sNjA42M3MSsbBbmZWMg52M7OScbDbgDNnnk+VZ+XmYLcBZfxVs4suwazXOdjNzErGwW5mVjIOdjOzknGw24AzeDCcOMsDqFZeDnYbcDyAamXnYDczKxkHu5lZyTjYbcB64qK5RZdg1isc7DYgTTpjGudsaCq6DLNe4WC3gamxsegKzHqNg93MrGQc7GZmJeNgtwHNA6hWRg52G7A8gGpl5WC3gcsDqFZSNQe7pGGSfitpqaTlkr6aR2FmZtYzu+WwjC3A8RGxSdIQYIGkn0fEb3JYtpmZdVPNLfZINmVXh2R/UetyzfqCj/RoZZRLH7ukwZKWAOuAuyLi/nbmmSlpkaRF6zdt2nkhZgXwkR6tjHIJ9ojYGhGHABOAIyU1tDPP3IiYGhFTx4wYkcdqzcysHbnuFRMRG4D5wIl5LtfMzLouj71ixkganV1+DXAC8EityzUzs57Jo8X+OuAeSQ8BD5D62H+Ww3LN+oQHUK1s8tgr5qGIODQiDoqIhoi4JI/CzPqKB1CtbPzLUzOzknGwm5mVjIPdzKxkHOxmeADVysXBboYHUK1cHOxmZiXjYDczKxkHu1mFI849qugSzGrmYDfLTDpjWtElmOXCwW5WYe+t64ouwaxmeZxByawcGhsZesvCNrs9fnv0bCbPmVlgUWbd52A3qzDu8ordHltaOGdeE8xqAhzy1n842M060tjIpMbGdLmpiU+vbWJrFvIAx0xcxbRp0DqLWb1wsJt1xezZjK+83tzMgiVT2DoPmAfPDh4LwANX7XRWSLM+52A364kZM3YEfUsLk4An5+3onz9m4ioAxo6FGTMKqdAGMAe7Wa2yvpjKbpulG1PAb1gNLE4terfmra842M3yNns2o7KLrf+Z1bS9Nf/t0WmA9skTZ7p/3nqFIqLPVzp1v/1i0cUX9/l6zQrX3Azr1vHiRtiwIU1aPGwaN512PeCBWOvc2WdrcURM3dV8brGb9aWsw30UO1rzg89t4sgbp7B1K9sHYi/78P0OeesxB7tZwdocMrilhaG3LGTOvCkwL7Xmzx+TWvOzfWRh6yIHu1k9aWxkXGtTvaWF0XcsZOnGKanbZlaafMc1qworz/oHB7tZvWpsZFQW8q3dNk+f37TTIQ9uGDnTrXlro+bBU0kTge8D+wDbgLkRcUVn9/HgqVkOmpsBeHLxjgOXXXRGas27f76c+nLw9FXggoh4UNJIYLGkuyLiDzks28w6kg3ETsp+APX0+U18vWoQ9oP73u/W/ABUc7BHxFPAU9nljZIeBsYDDnazPlR9ALNJCxdy7+op2/vmHfQDR6597JImA4cCO/3ETtJMYCbApL32ynO1ZlatsTEdxKxi0uiLmtoEPcCHD1/lQx6UUG7BLmkEcBPwuYh4sfr2iJgLzIXUx57Xes2sa0bN2fGLWIAXL2rixsVTth/y4MaRM3xY4pLIJdglDSGF+g8i4id5LNPMeldl0E9qbmbvxTuOPQ+pNX/AAR6I7Y9qDnZJApqBhyPi8tpLMrM+N2PG9kFYSK35W5dPYctiePbGdEhit+j7jzxa7NOAjwG/l7Qkm3ZRRNyew7LNrABtWvMtLQDsXXE2qWMmrvIgbB3LY6+YBYByqMXM6lHVYYlfvKiJ36yfwpaKQVifNrC++JenZtYt1YOwNDdzTkX//OJh0wBYf/n1fV+cAQ52M6tVZf98czOTWMGfl6xja8XZpMaO3T6r9QEHu5nlJ0vu7acNbG7mN8unwHrY8go+m1QfcbCbWe+ZMYNx1dMqziblc8P2Dge7mfWpSddku9M0NbGKKelsUj43bK4c7GZWjGx/ycqzSVW25mHHaQP9I6nucbCbWd3Y3poHaG5m9B8XcnjF2aS+9ta0p427bTrnYDez+jRjRpvW/NDzs1/DZoOw4LNJdcTBbmb9QpvDEuOzSXWm5jMo9YTPoGRmuWppgYULeXL1jkllbM335RmUzMyKVXX8+T+f28TJ56bW/NatqX/+/DHXD5jWvIPdzEpn/FVVZ5NasYLDFw+ck4w42M2s3Fpb8xUh/vT5O04y0qpMQe9gN7MBp3oglqZynU3KwW5mNnv29v75SU1N7L267dmkLjpjVb/6kZSD3cysUkXIQzr+/NdvnMLWeak1D/DEkAPq+rDEDnYzs060fzaphZDtQ1+PffMOdjOzrmrnbFK3Lq+/s0k52M3Memins0m1tHBOxblhvz06DdL2ddD7l6dmZr2hKYX7n9emH0lBOv78tGn0eCDWvzw1MytS9jPXyrNJLViSBmGpGIjtjePPO9jNzPrCjBk7Qr6lhUnAk/MW9srZpHIJdknfBU4B1kVEQx7LrEvbtsGyZbB6NUycCA0NMGhQ0VWZWX9TNQhLUxNLN6aAz+NsUnm12K8DvgV8P6fl1Z9t2+CKK2DVKtiyBYYOhSlT4LzzHO5mVpvZOwZhOzqbVOtAbFfkEuwR0SJpch7LqlvLlu0IdUj/V61K0w86qNjazKx0qs8mNWddE03Pdu2+fdbUlDRT0iJJi9Zv2tRXq83P6tU7Qr3Vli1puplZb5oxg+4cc7jPgj0i5kbE1IiYOmbEiL5abX4mTkzdL5WGDk3TzczqiDuHu6qhIfWpt4Z7ax97Q3nHis2sf/Lujl01aFAaKPVeMWZW5/La3fEG4FjgtZLWAF+OiOY8ll1XBg1KA6UeLDWzOpbXXjFn5rEcMzOrnfsRzMxKxsFuZlYyDnYzs5JxsJuZlYyD3cysZBzsZmYl42A3MysZB7uZWck42M3MSsbBbmZWMg52M7OScbCbmZWMg93MrGQc7GZmJeNgNzMrGQe7mVnJONjNzErGwW5mVjIOdjOzknGwm5mVjIPdzKxkHOxmZiWTS7BLOlHSo5Iel3RhHss0M7Oe2a3WBUgaDFwN/D2wBnhA0q0R8Ydal113tm2DZctg9WqYOBEaGmCQN3rMrL7UHOzAkcDjEbESQNI84FSgXMG+bRtccQWsWgVbtsDQoTBlCpx3nsPdzOpKHok0HlhdcX1NNq1cli3bEeqQ/q9alaabmdWRPIJd7UyLnWaSZkpaJGnR+k2bclhtH1u9ekeot9qyJU03M6sjeQT7GmBixfUJwNrqmSJibkRMjYipY0aMyGG1fWzixNT9Umno0DTdzKyO5BHsDwAHSJoiaXfgDODWHJZbXxoaUp96a7i39rE3NBRbl5lZlZoHTyPiVUmfAX4BDAa+GxHLa66s3gwalAZKvVeMmdW5PPaKISJuB27PY1l1bdAgOOig9GdmVqfc3DQzKxkHu5lZyTjYzcxKxsFuZlYyDnYzs5JxsJuZlYyD3cysZBzsZmYl42A3MysZB7uZWck42M3MSsbBbmZWMg52M7OScbCbmZWMg93MrGQc7GZmJeNgNzMrGQe7mVnJONjNzErGwW5mVjIOdjOzknGwm5mVjIPdzKxkagp2SR+WtFzSNklT8yrKzMx6rtYW+zLgg0BLDrWYmVkOdqvlzhHxMICkfKoxM7Oa1RTs3SFpJjAzu7pFZ5+9rK/W3QteCzxTdBE1cP3Fcv3F6s/179eVmXYZ7JL+H7BPOzddHBE/7Wo1ETEXmJstc1FE9Ns+eddfLNdfLNdf/3YZ7BFxQl8UYmZm+fDujmZmJVPr7o4fkLQGeAdwm6RfdPGuc2tZbx1w/cVy/cVy/XVOEVF0DWZmliN3xZiZlYyD3cysZAoL9v54OAJJJ0p6VNLjki4sup7ukvRdSesk9bvfEEiaKOkeSQ9n75vziq6pOyQNk/RbSUuz+r9adE09IWmwpN9J+lnRtXSXpCck/V7SEkmLiq6nNxXZYu9XhyOQNBi4GjgJeAtwpqS3FFtVt10HnFh0ET30KnBBRBwIvB34dD97/rcAx0fEwcAhwImS3l5wTT1xHvBw0UXU4LiIOKTs+7EXFuwR8XBEPFrU+nvgSODxiFgZEX8F5gGnFlxTt0REC/Bc0XX0REQ8FREPZpc3ksJlfLFVdV0km7KrQ7K/frXngqQJwMnAd4quxTrnPvauGw+srri+hn4ULGUiaTJwKHB/sZV0T9aNsQRYB9wVEf2qfuCbwBeAbUUX0kMB3ClpcXaIk9Lq1WPF5HU4gjrR3pHO+lWLqwwkjQBuAj4XES8WXU93RMRW4BBJo4GbJTVERL8Y75B0CrAuIhZLOrboenpoWkSslTQWuEvSI9lWbOn0arCX7HAEa4CJFdcnAGsLqmVAkjSEFOo/iIifFF1PT0XEBknzSeMd/SLYgWnA+yW9FxgGjJJ0fUR8tOC6uiwi1mb/10m6mdS9Wspgd1dM1z0AHCBpiqTdgTOAWwuuacBQOjZ0M/BwRFxedD3dJWlM1lJH0muAE4BHiq2q6yJidkRMiIjJpPf+L/tTqEvaQ9LI1svAdPrPl2q3Fbm7Y08PR1CIiHgV+AzwC9LA3Y8jYnmxVXWPpBuAXwNvkrRG0oyia+qGacDHgOOz3dWWZK3H/uJ1wD2SHiI1Eu6KiH63y2A/Ng5YIGkp8Fvgtoi4o+Caeo0PKWBmVjLuijEzKxkHu5lZyTjYzcxKxsFuZlYyDnYzs5JxsJuZlYyD3cysZP4/f0f2mpLlotMAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "#### Problem 9.1 (b)\n", + "Z_train = data.input_whitening(X_train)\n", + "print(f\"Transformed data points: {Z_train}\")\n", + "cls = nn.NearestNeighbors(Z_train, y_train, 1, 'classification')\n", + "x1_min, x1_max = -1, 6\n", + "x2_min, x2_max = -1, 6\n", + "xx1, xx2 = myplot.get_grid(x1_min, x1_max, x2_min, x2_max, step=0.02)\n", + "myplot.plot_decision_boundaries(xx1, xx2, 2, cls, data.input_whitening)\n", + "\n", + "myplot.plt_plot([xsp1, xsm1], [ysp1, ysm1], 'scatter', \n", + " colors = ['r', 'b'], markers = ['o', '+'], labels = ['+1', '-1'], \n", + " title = \"Problem 9.1 (b) Whitening + 1-Nearest Neighbor\", yscale = None, ylb = -1, yub = 6,\n", + " xlb = -1, xub = 6, xlabel = None, ylabel = None,\n", + " legends = ['+1', '-1'], legendx = None, legendy = None, marker_sizes=[25, 25])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transformed data points: [[0. ]\n", + " [0.67507785]\n", + " [7.06412174]]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEICAYAAABLdt/UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGfNJREFUeJzt3X+YHFWd7/H3Z5KQECaCQCRABgIq6jpAYKOiESSAMfxQQAV/LF51WSMrSLgXZXFR0b0rq/eKF3TZ6+YBxF0RFq8QuQSQEIwsrIAhhhAIICZAQgiJhMAkkWgy3/3j1CydzkxmerqS7j75vJ6nn+murj71rZ6qT586VdOjiMDMzPLR1ugCzMysXA52M7PMONjNzDLjYDczy4yD3cwsMw52M7PMONj7IOlrkn60leefknTc9qxpW5A0WdKMAcx3rqRvbo+azGoh6euS/nGA835T0pVbeX6FpHeXV11jZBXsRdj+QdJaSc9L+oGk9kbXVQ9JwyX9H0nLJb0o6Z8kDdvK/NMlPS6pW9KnBrCIS4CBBPZ04AxJr+tlmfsV73nPLSStq3h85ADaHzBJ/13SPEl/lPT9fuY9S9LGoo6Xi9dNqXj+tZK+J2lpMc+Tkr4tafeqdu4rtqmhZa7LYNapmH+FpGcljaiYdo6k27dFfYNVvP939jPPfcX2MqZi2kmSHhvIMiLi4og4p95ac5JVsBfeHxHtwOHA24AvV8+gpFXW/UJgAtAJHERary3WqcJDwOeAef01LOltwK4RcV9/80bEK8BtwH/r5blnIqK951ZMPrRi2r/3136NlgFfA/o8oqoyp6jrtcB1wE8ktUvaGZgDvB44DngN8G5gPel9BkDSm0jb0k7A8QMtsgi1fkO6UOs6AYwg/a63qW31YVblFeBvt8Nytovt9J71qVXCrWYR8SwpiDoBJM2R9A1J95J23AMl7SPpZkmri57aZ6qaGSHp3yR1Fb2pQ3tblqQ2SRdK+p2kFyTd0NPjkzSu6MF+uugVvljs8G+TtEDSmn4OI98PfDciVkfEKuC7wF9uZb2viIjZpB2lP8cDv6xal7dKmlW8J89LqtzZ5gAnDqDdLUjaXdKPJa2StETSBZJUPHeWpLsk/XPRq35U0lF9tRURP4mIm4HVtdQQEZuAq4F2YBxwJrAH8KGIeDwiuiNiRUR8NSIqe5mfJK37dcX90g1ynf4XcGFfR6WSOov39UVJiySdUvHcqZIeKt7vpyt/z5LeXBzlfEbSUuDWYvqRku4vttl5kiZWvOYzSkfMXZIWSzpN0mHAZcDRxdHQiq2sy2XApyXt18e6dEj6maTfF+2fVfHcZsMrkv5K0jPFtnaBthxe2VnSdUWtCySNr1rcuyQ9VuwD0yUNr2j77Ir9/EZJexXTRxT7+V9L+h2wcCvrus1lG+ySOoATgN9UTP4EMBUYBTxN2lGXAfsAHwYukXRsxfwnAz8Bdgd+DMxQ78Mg5wKnAO8p2noRuKJqnncAbwQ+QtqILyL1Et8KnC7pPX2tSnGrfDxW0q59rXsNDgYe/6+GpVHAncDtpPV4AzC7Yv5FQK8fbgPwfWAYcADwXuCvgY9XPH8U6WhjD9LQ0AxJrxnksnpV9KLOBF4ClpDe/5kR8YetvKYNOAO4tridJOm1ZdZVh3uBucB51U8U790s4CpgT9KR1tWS3lDM8jLp/d8NOBX4giqGqIAhpG32TcDJksYBM0jb7e6ko8YZSkNZrwX+N3BsRIwCjgQWRsRvitrmFEduY+jbEuBfga/2si5DSB8u/0HaLqcAf9vbPlOE9HeA04GxxW3PqtlOJX3A70bavi+rev5jwDHFuh8GfLFo+wTgK8Xr9wV+z5ZHWCcBf168rmFyDPYZktYA95B6o5dUPHdNRDwSERuBMaTD7r+JiFciYj5wJSn8ezwYEf8vIv5E2lhGAEf0sszPAhdFxLKI2EA6pP6wNj8c+5/Fcu4A1gHXRcTK4sji3+l7Q7gNmCZptNIY5LnF9JEDfD+2Zjegq+LxScCKiLi0qLUrIu6veL4LqPkDpejxfIj0Xq+NiCdJO1Ple700Iv4pIv4UEf9C+sB9X63L6sN7im1iBenD+pSIWEf6EHmun9ceA7wO+CkpWJ4DPlpSXWX4MnB+Lx82p5LC9dqI2BQRvwb+P+n3QETMLvaF7oiYB9xA6phU+mpErC8++D4J3BgRdxavuRV4FJhcMX+npBER8WxELBrEuvw98BFJb6ya/m5gRER8KyL+GBFPAD+g99/D6cBPI+K+Yl/8Mlvm3F0RMas4gvtXoLrHfnlELC+OkP+BFPQAfwFMj4gFxdDkBcCxqjg3AHwjItZsrbOwPeQY7KdExG4RsX9EfK7qDV5acX8fYHVEVAbb06RP4i3mj4huXu3dV9sfuKk4RF1D6tluAvaqmOf5ivt/6OVxXyd5v0E66phPCpYZwJ+AlX3MX4sXSUcvPTqA321l/lGk3m6txpC2tWcqplW/18uqXvM0vb/Xg/HLYpvYMyImRsScYvoLwN79vPaTpF79S5G+MW+rwzGSrqrYDr5DGl5YU9weGEzxxXBKz4noD1U+V4TyXcAXql62P3BUxbLXkEJ976LNiZJ+WQxXvAR8is17tt0RsbyqvTOq2psA7BMRL5JC71xghdLw5huoUbG86aSOUfW6jKta9v8gbVfV9mHz/fZlttxmK4eE1rPlvleZE5Xb4T7F456215COfHrNjEZq6AB/A1R+leVyYHdJoyrCfT/g2Yp5OnruFIfkY4vXVVsK/GVE3Fv9RHEIO/iC0wfTOcUNSVNJRxKb6mm3sIB0QrbHUl7tnfTmLaThklqtALpJ7+/iYlr1ez226jX70ft7XaY7gb8peplbnJMoxq4/CHRXjA8PB3aT9KaIeLz6NRFxJmm4h2IceHxEnFU9Xy0i4ph+Zvkq6UP/8oppS4E7IuL9fbzmBlKn4eqIeEXpJG9lHlR/7etS4MqI+HwfNc4EZkoaSRr7/7+kIbdavz72m8CTwMNVy34sIg4ewOufo2JbKoakaj3K7Ki4X7kdLid9yPS0vSvphHvldtwUX5ebY499QCJiKWln+IfixMchpB3y2orZ/lzSB4shlfOADUBvV5B8H/iGpP0BimGTk8uoU9K+Sid5JekI0hjfxVuZfyelS+AEDCvWra/f861sfvh9CzBG0nlKl1mOkvSOiuffQxoaqklxSHwT6RzGLpJeD0xj8/HJDqWTqEMlnUHaoe7oYx2HFus4BBhSrOOQWusijT+vJl0lc1DxHo+WdHFxruU0YC3wZtLh+njSh9sD9HJ1UD3qWaeIeAS4GTi7YvIM4DBJH5E0rNgujuhZT1Iv9YUi1N9FWtet+SFwmqRjJQ2RtHNxf0yxjZ5YhPoG0nvW0/F4nvS77fMS3ap16blA4IsVk+8BKLbLEcV7dYikw3tp4gbgQ0oXJ+wE/B2pU1GLcyXtLWlP0lVp/1ZMvw74jNJJ6RHAt0jDOls7KdwQO2ywFz5GujpiOSl4Lo6IWRXP/4x0svNF0njwB4vx9mqXk3asOyR1kcL/Hb3MNxivJ30ArSPtXBcW4/QASLpNm1+5cgdpaOddpMPaP5BOTG6hOIx/qSe8iyOX95KuxFkB/BaYVCxnBOlk9A8HuR6fLX4+TRo6uJLNP0TvJp1nWE06QXdqRPQ17PP3xXqdB/xVcf+Lfczbp+Jo6OiKmrqAXwG7kC4X/SSpl/pscbXMimInvgL4xFY+MAej3nX6Gqn3CEAxPPI+4NOkXuzyYhnDiiGls4BvF9vrBaSLBPoUEYtJQzlfJ500fJr04dxG+jD6EmmbeYF0aWhPz/524ClgpaTq4ba+fJuKCwaKfe4E0jb9NLCKdESwxfBlccL2i6T9+dli3V8ifeAM1PXAL0jb/8OkIxAi4hbSmPvNpPdzDJufJ2oaCv+jjR2apMnA5yLilH7m+zzQEREXbIMazgI+HBEt/5e81lyKk8qrSecC+jtRno0dbYzdqhS9/16HPKrm+952KMesbpI+QLrUcwjpBPb9O1Kog4dizCw/p5GGhZaRrlj5i8aWs/15KMbMLDPusZuZZaYhY+x7trfHuD32aMSizawGqxjd6BKswjPPPPj7iOj3l9KQYB+3xx7MveiiRizazGownamNLsEqfPazerr/uTwUY2aWHQe7mVlmHOxmZplpmj9Q+tPQoSw78EBeGVnGt9E2zoj16xm7eDHDNm5sdClmtoNqmmBfduCBjOroYNyoUaTvKGo9EcELXV0sAw544olGl2NmO6imGYp5ZeRI9mjhUAeQxB6jRrX8UYeZtbamCXagpUO9Rw7rYGatramC3czM6udgH6Cf3Hwzbz3ySNr22ou58+c3uhwzsz452Hsx5957+dTnN/8PYJ1vfjM3/uAHHPXOdzaoKjOzgWmaq2JqtmkTzJ4NDz8MBx8Mxx4LQwbz39EG5i0HHdT/TGZmTaA1g33TJjj9dJg3D9avh5Ej4fDD4YYbtmm4m5m1gtYM9tmzU6ivW5cer1uXHs+eDZMnD7rZd0yZwoYNG1i7bh2r16xh/KRJAHzrK1/hfcf094/izcyaQ2sG+8MPp556pfXrYeHCuoL9/ttvB9IY+zXXX8813/N/gzOz1tOaJ08PPjgNv1QaORI6OxtTj5lZE2nNYD/22DSmvssuIKWfhx+epm8jN82cydhDD+VXc+dy4sc/zvtOP32bLcvMrB6tORQzZEg6UTp7dhp+6ews9aqYoydO5OiJEzebduqJJ3LqiSeW0r6Z2bbUmsEOKcQnT65rTN3MLEetORRjZmZ9crCbmWXGwW5mlhkHu5lZZhzsZmaZcbD34rHf/pZ3Hn88w8eO5dtXXNHocszMalLK5Y6SngK6gE3AxoiYUEa7jbL7brvx3UsuYcZttzW6FDOzmpXZY58UEeNbPdQBXjd6NG877DCGDW3dy/zNbOsuvTTdctTSyXX0KbsCMGfGSw2uxMxayaWXwhNPvHof4PzzG1dP2coK9gDukBTAP0fE9OoZJE0FpgLst/vudS/w6FN25Zf/sdN/3QcHvJkZlDcUMzEiDgeOB86WdFT1DBExPSImRMSE0e3tJS22PFdcdRXjJ01i/KRJLF+xotHlmNk2dP75cNBB6Xb++Xn11qGkHntELC9+rpR0E/B24O4y2u7LnBkvldpTP/vMMzn7zDPrbsfMrNHqDnZJuwBtEdFV3J8M/F3dlTXQiuefZ8Lkybzc1UVbWxuXTZ/Oo/fcw2tGjWp0aWZWktx66ZXK6LHvBdwkqae9H0fE7SW0269tNaY+Zq+9WPbQQ9ukbTOzba3uYI+IxcChJdRiZmYl8F+empllpqmCPSIaXULdclgHM2ttTRPsI9av54WurpYOxojgha4uRqxf3+hSzGwH1jR/eTp28WKWAatGjmx0KXUZsX49YxcvbnQZZrYDa5pgH7ZxIwf0/I2vmZkNWtMMxZiZWTkc7GZmmXGwm5llxsFuZpYZB7uZWWYc7GZmmXGwm5llxsFuZpYZB7uZWWYc7GZmmXGwm5llxsFuZpYZB7uZWWYc7GZmmXGwm5llxsFuZpYZB7uZWWZKC3ZJQyT9RtItZbVpZma1K7PHPg1YVGJ7ZmY2CKUEu6SxwInAlWW0Z2Zmg1dWj/0y4AKgu68ZJE2VNFfS3FVr15a0WDMzq1Z3sEs6CVgZEQ9ubb6ImB4REyJiwuj29noXa2ZmfSijxz4R+ICkp4DrgWMk/aiEds3MbBDqDvaI+FJEjI2IccBHgbsi4oy6KzMzs0HxdexmZpkZWmZjETEHmFNmm2ZmVhv32M3MMuNgNzPLjIPdzCwzDnYzs8w42M3MMuNgNzPLjIPdzCwzDnYzs8w42M3MMuNgNzPLjIPdzCwzDnYzs8w42M3MMuNgNzPLjIPdzCwzDnYzs8w42M3MMuNgNzPLjIPdzCwzDnYzs8w42M3MMuNgNzPLjIPdzCwzdQe7pBGSHpD0kKRHJH29jMLMzGxwhpbQxgbgmIhYK2kYcI+k2yLivhLaNjOzGtUd7BERwNri4bDiFvW2a2Zmg1PKGLukIZLmAyuBWRFxfy/zTJU0V9LcVWvXbtmImZmVopRgj4hNETEeGAu8XVJnL/NMj4gJETFhdHt7GYs1M7NelHpVTESsAeYAU8ps18zMBq6Mq2JGS9qtuL8zcBzwWL3tmpnZ4JRxVczewA8lDSF9UNwQEbeU0K6ZmQ1CGVfFLAAOK6EWMzMrgf/y1MwsMw52M7PMONjNzDLjYDczy4yD3cwsMw52M7PMONjNzDLjYDczy4yD3cwsMw52M7PMONjNzDLjYDczy4yD3cwsMw52M7PMONjNzDLjYDczy4yD3cwsMw52M7PMONjNzDLjYDczy4yD3cwsMw52M7PM1B3skjok/ULSIkmPSJpWRmFmZjY4Q0toYyNwfkTMkzQKeFDSrIh4tIS2zcysRnX32CPiuYiYV9zvAhYB+9bbrpmZDU6pY+ySxgGHAff38txUSXMlzV21dm2ZizUzswqlBbukduCnwHkR8XL18xExPSImRMSE0e3tZS3WzMyqlBLskoaRQv3aiLixjDbNzGxwyrgqRsBVwKKI+E79JZmZWT3K6LFPBD4BHCNpfnE7oYR2zcxsEOq+3DEi7gFUQi1mZlYC/+WpmVlmHOxmZplxsJuZZcbBbmaWGQe7mVlmHOxmZplxsJuZZcbBbmaWGQe7mVlmHOxmZplxsJuZZcbBbmaWGQe7mVlmHOxmZplxsJuZZcbBbmaWGQe7mVlmHOxmZplxsJuZZcbBbmaWGQe7mVlmHOxmZplxsJuZZaaUYJd0taSVkhaW0V7T6u6GBQtg5sz0s7u70RWZmW1haEntXAP8I/AvJbXXfLq74fLLYckS2LABhg+HAw6AadOgzQc+ZtY8SkmkiLgbWF1GW01r4cJXQx3SzyVL0nQzsyay3bqakqZKmitp7qq1a7fXYsuzdOmrod5jw4Y03cysiWy3YI+I6RExISImjG5v316LLU9HRxp+qTR8eJpuZtZEPDg8UJ2daUy9J9x7xtg7Oxtbl5lZlbJOnuavrS2dKF24MA2/dHSkUPeJUzNrMqUEu6TrgKOBPSUtAy6OiKvKaLuptLXBIYekm5lZkyol2CPiY2W0Y2Zm9fM4gplZZhzsZmaZcbCbmWXGwW5mlhkHu5lZZhzsZmaZcbCbmWXGwW5mlhkHu5lZZhzsZmaZcbCbmWXGwW5mlhkHu5lZZhzsZmaZcbCbmWXGwW5mlhkHu5lZZhzsZmaZcbCbmWXGwW5mlhkHu5lZZhzsZmaZKSXYJU2R9LikJyVdWEabTam7GxYsgJkz08/u7kZXZGa2haH1NiBpCHAF8F5gGfBrSTdHxKP1tt1Uurvh8sthyRLYsAGGD4cDDoBp06DNBz5m1jzKSKS3A09GxOKI+CNwPXByCe02l4ULXw11SD+XLEnTzcyaSBnBvi+wtOLxsmLaZiRNlTRX0txVa9eWsNjtbOnSV0O9x4YNabqZWRMpI9jVy7TYYkLE9IiYEBETRre3l7DY7ayjIw2/VBo+PE03M2siZQT7MqAy3cYCy0tot7l0dqYx9Z5w7xlj7+xsbF1mZlXqPnkK/Bp4o6QDgGeBjwIfL6Hd5tLWlk6ULlyYhl86OlKo+8SpmTWZuoM9IjZKOgf4OTAEuDoiHqm7smbU1gaHHJJuZmZNqoweOxFxK3BrGW2ZmVl9PI5gZpYZB7uZWWYc7GZmmXGwm5llxsFuZpYZB7uZWWYc7GZmmXGwm5llxsFuZpYZB7uZWWYc7GZmmXGwm5llxsFuZpYZB7uZWWYc7GZmmXGwm5llxsFuZpYZB7uZWWYc7GZmmXGwm5llxsFuZpYZB7uZWWYc7GZmmakr2CWdJukRSd2SJpRVlJmZDV69PfaFwAeBu0uoxczMSjC0nhdHxCIASeVUY2ZmdVNE1N+INAf4QkTM3co8U4GpxcNOUm+/Ve0J/L7RRdTJ69B4rV4/tP46tFr9+0fE6P5m6rfHLulOYEwvT10UET8baDURMR2YXrQ5NyJadky+1esHr0MzaPX6ofXXodXr70u/wR4Rx22PQszMrBy+3NHMLDP1Xu54qqRlwDuBmZJ+PsCXTq9nuU2g1esHr0MzaPX6ofXXodXr71UpJ0/NzKx5eCjGzCwzDnYzs8w0LNhb9esIJE2R9LikJyVd2Oh6aiXpakkrJbXk3xFI6pD0C0mLiu1nWqNrqpWkEZIekPRQsQ5fb3RNgyFpiKTfSLql0bUMhqSnJD0sab6kPv8GpxU1ssfecl9HIGkIcAVwPPBnwMck/Vljq6rZNcCURhdRh43A+RHxFuAI4OwW/B1sAI6JiEOB8cAUSUc0uKbBmAYsanQRdZoUEeNzu5a9YcEeEYsi4vFGLX+Q3g48GRGLI+KPwPXAyQ2uqSYRcTewutF1DFZEPBcR84r7XaRg2bexVdUmkrXFw2HFraWuYpA0FjgRuLLRtdiWPMZem32BpRWPl9FioZITSeOAw4D7G1tJ7YphjPnASmBWRLTaOlwGXAB0N7qQOgRwh6QHi688yUZdXwLWn7K+jqCJ9PZtZy3V08qFpHbgp8B5EfFyo+upVURsAsZL2g24SVJnRLTEeQ9JJwErI+JBSUc3up46TIyI5ZJeB8yS9FhxRNvytmmwZ/h1BMuAjorHY4HlDaplhyVpGCnUr42IGxtdTz0iYk3xJXpTaJ0vxpsIfEDSCcAI4DWSfhQRZzS4rppExPLi50pJN5GGWrMIdg/F1ObXwBslHSBpJ+CjwM0NrmmHovQd0VcBiyLiO42uZzAkjS566kjaGTgOeKyxVQ1cRHwpIsZGxDjSPnBXq4W6pF0kjeq5D0ymdT5Y+9XIyx0H+3UEDRMRG4FzgJ+TTtrdEBGPNLaq2ki6DvgV8CZJyySd2eiaajQR+ARwTHGZ2vyi59hK9gZ+IWkBqbMwKyJa8pLBFrYXcI+kh4AHgJkRcXuDayqNv1LAzCwzHooxM8uMg93MLDMOdjOzzDjYzcwy42A3M8uMg93MLDMOdjOzzPwnyY6hzY3boP8AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "#### Problem 9.1 (c)\n", + "\n", + "def pca_transformer(X):\n", + " Z, _, _ = data.pca(X, 1)\n", + " return Z\n", + "\n", + "Z_train = pca_transformer(X_train)\n", + "print(f\"Transformed data points: {Z_train}\")\n", + "cls = nn.NearestNeighbors(Z_train, y_train, 1, 'classification')\n", + "x1_min, x1_max = -1, 6\n", + "x2_min, x2_max = -1, 6\n", + "xx1, xx2 = myplot.get_grid(x1_min, x1_max, x2_min, x2_max, step=0.1)\n", + "myplot.plot_decision_boundaries(xx1, xx2, 2, cls, pca_transformer)\n", + "\n", + "myplot.plt_plot([xsp1, xsm1], [ysp1, ysm1], 'scatter', \n", + " colors = ['r', 'b'], markers = ['o', '+'], labels = ['+1', '-1'], \n", + " title = \"Problem 9.1 (c) Top 1 PCA + 1-Nearest Neighbor\", yscale = None, ylb = -1, yub = 6,\n", + " xlb = -1, xub = 6, xlabel = None, ylabel = None,\n", + " legends = ['+1', '-1'], legendx = None, legendy = None, marker_sizes=[25, 25])" ] }, { @@ -465,6 +601,48 @@ "#### Problem 9.6 TODO" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Add lib input sys.path\n", + "import os\n", + "import sys\n", + "import time\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import sklearn\n", + "import matplotlib.pyplot as plt\n", + "from scipy.optimize import minimize\n", + "import math\n", + "from sklearn.preprocessing import normalize\n", + "from functools import partial\n", + "import h5py\n", + "from scipy.spatial import distance\n", + "\n", + "nb_dir = os.path.split(os.getcwd())[0]\n", + "if nb_dir not in sys.path:\n", + " sys.path.append(nb_dir)\n", + "\n", + "from matplotlib.colors import ListedColormap\n", + "import libs.linear_models as lm\n", + "import libs.data_util as data\n", + "import libs.nn as nn\n", + "import libs.plot as myplot\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/libs/data_util.py b/libs/data_util.py index 1006dca..03d6e17 100644 --- a/libs/data_util.py +++ b/libs/data_util.py @@ -4,6 +4,7 @@ import functools import h5py from sklearn.model_selection import StratifiedShuffleSplit +from scipy.linalg import sqrtm def generate_random_numbers01(N, dim, max_v = 10000): @@ -272,6 +273,25 @@ def sample_zip_data(X, y, train_size, splits): data_indices.append([X_train, y_train, X_test, y_test]) return data_indices +# Deal with ZIP code data +def split_zip_data(zip_data_path, splits = 1, train_size = 500): + # Split the raw data into train and test + # splits: specify the number of random splits for each train-test pair + X_tr, y_tr, X_te, y_te = load_zip_data(zip_data_path) + train_size = train_size + splits = splits + data_splits = sample_zip_data(X_tr, y_tr, train_size, splits) + return data_splits + +def set_two_classes(y_train, y_test, digit): + # Classify digit '1' vs. not '1' + y_train[y_train==digit] = 1 + y_test[y_test==digit] = 1 + + y_train[y_train!=digit] = -1 + y_test[y_test!=digit] = -1 + return y_train, y_test + def calc_image_symmetry(X, img_w, img_h): """We define asymmetry as the average absolute difference between an image and its flipped versions, and symmetry as the negation of asymmetry @@ -319,3 +339,38 @@ def compute_features(X_train, X_test): +# Input Centering +def input_centering(X): + # Make the mean of X to be zero + N, _ = X.shape + mean_x = np.mean(X, axis = 0).reshape(1, -1) + ones = np.ones((N,1)) + Z = X - np.matmul(ones, mean_x) + return Z + +def input_whitening(X): + # Center the data first + N, _ = X.shape + XX = input_centering(X) + COV = np.matmul(XX.transpose(), XX)/N + sqrt_COV = sqrtm(COV) + Z = np.matmul(XX, np.linalg.inv(sqrt_COV)) + return Z + +def pca(X, top_k, center_first = True): + #PAC dimension reduction to top_k + if top_k < 1: + raise ValueError(f"The reduced dimension {top_k} has to be larger than 0") + + N, d = X.shape + if center_first: + XX = input_centering(X) + else: + XX = X + U, S, V = np.linalg.svd(XX) + Vk = V[:, :top_k] + Z = np.matmul(X, Vk) + X_hat = np.matmul(X, Vk) + X_hat = np.matmul(X_hat, Vk.transpose()) + return Z, X_hat, S + diff --git a/libs/nn.py b/libs/nn.py index 0c875c9..acfd42e 100644 --- a/libs/nn.py +++ b/libs/nn.py @@ -65,12 +65,13 @@ def find_nn_idx(x, X, k): return order[:k], distances[order[:k]] class NearestNeighbors: - def __init__(self, X, y, k, problem_type='classification'): + def __init__(self, X, y, k, problem_type='classification', transformer=None): #X: Nxd matrix, where each row corresponds to a data point x in R^d self.X = X self.y = y self.k = k #number of nearest neighbors self.problem_type = problem_type + self.transformer = transformer def find_nn_idx(self, x, k): # Find the indexes of k nearest neighbors for x @@ -101,7 +102,9 @@ def predict_one(self, x): def predict(self, X): # Predict the y for input X: Mxd matrix - + if self.transformer is not None: + X = self.transformer(X) + M, _ = X.shape predicted = [] for idx in np.arange(M):