Skip to content

Commit

Permalink
Added problems 9.1
Browse files Browse the repository at this point in the history
  • Loading branch information
niuers committed Jul 10, 2020
1 parent 3bd8621 commit c97f35d
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 33 deletions.
33 changes: 4 additions & 29 deletions Solutions to Chapter 6 Similarity-Based Methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1377,31 +1377,6 @@
"#### Problem 6.14 (a) Prepare Zip Code Data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def split_zip_data(zip_data_path, splits = 1, train_size = 500):\n",
" # Split the raw data into train and test\n",
" # splits: specify the number of random splits for each train-test pair\n",
" X_tr, y_tr, X_te, y_te = data.load_zip_data(zip_data_path)\n",
" train_size = train_size\n",
" splits = splits\n",
" data_splits = data.sample_zip_data(X_tr, y_tr, train_size, splits)\n",
" return data_splits\n",
"\n",
"def set_two_classes(y_train, y_test, digit): \n",
" # Classify digit '1' vs. not '1'\n",
" y_train[y_train==digit] = 1\n",
" y_test[y_test==digit] = 1\n",
" \n",
" y_train[y_train!=digit] = -1\n",
" y_test[y_test!=digit] = -1\n",
" return y_train, y_test"
]
},
{
"cell_type": "code",
"execution_count": 3,
Expand All @@ -1419,7 +1394,7 @@
],
"source": [
"zip_data_path = './data/usps.h5'\n",
"data_splits = split_zip_data(zip_data_path, splits = 1)\n",
"data_splits = data.split_zip_data(zip_data_path, splits = 1)\n",
"\n",
"X_train, y_train, X_test, y_test = data_splits[0]\n",
"\n",
Expand All @@ -1428,7 +1403,7 @@
"freqs = counts/len(y_train)\n",
"print('Frequencies of the digits: \\n', dict(zip(unique, freqs)))\n",
"\n",
"y_train, y_test = set_two_classes(y_train, y_test, 1)"
"y_train, y_test = data.set_two_classes(y_train, y_test, 1)"
]
},
{
Expand Down Expand Up @@ -1692,15 +1667,15 @@
"k=3\n",
"tot_exps = 1000 \n",
"zip_data_path = './data/usps.h5'\n",
"data_splits = split_zip_data(zip_data_path, splits = tot_exps)\n",
"data_splits = data.split_zip_data(zip_data_path, splits = tot_exps)\n",
"digit = 1 #we classify digit '1' vs. non '1'\n",
"nn_Eins, nn_Eouts = [], []\n",
"cnn_Eins, cnn_Eouts = [], []\n",
"for it in range(tot_exps):\n",
" if (it + 100) % 100 == 0:\n",
" print('---- Working on iteration: ', it)\n",
" X_train, y_train, X_test, y_test = data_splits[it]\n",
" y_train, y_test = set_two_classes(y_train, y_test, digit)\n",
" y_train, y_test = data.set_two_classes(y_train, y_test, digit)\n",
" X_tr, X_te = data.compute_features(X_train, X_test)\n",
" \n",
" nn_cls = nn.NearestNeighbors(X_tr, y_train, k)\n",
Expand Down
Loading

0 comments on commit c97f35d

Please sign in to comment.