diff --git a/examples/cifar/analyze.py b/examples/cifar/analyze.py index d549c16..b892538 100644 --- a/examples/cifar/analyze.py +++ b/examples/cifar/analyze.py @@ -125,7 +125,6 @@ def main(): per_device_batch_size=None, factor_args=factor_args, overwrite_output_dir=True, - initial_per_device_batch_size_attempt=8192, ) analyzer.compute_pairwise_scores( scores_name="pairwise", @@ -135,7 +134,6 @@ def main(): train_dataset=train_dataset, per_device_query_batch_size=500, overwrite_output_dir=True, - initial_per_device_train_batch_size_attempt=8192, ) scores = analyzer.load_pairwise_scores("pairwise") print(scores) diff --git a/examples/uci/tutorial.ipynb b/examples/uci/tutorial.ipynb index de9b154..ee64cb6 100644 --- a/examples/uci/tutorial.ipynb +++ b/examples/uci/tutorial.ipynb @@ -1241,7 +1241,7 @@ }, "outputs": [], "source": [ - "random_indices = list(range(summed_scores.shape[0]))\n", + "random_indices = list(range(len(train_dataset)))\n", "shuffle(random_indices)\n", "random_removed_loss_lst = []\n", "\n", @@ -1257,10 +1257,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 41, "id": "9e6fd656-c1e5-42cb-87f4-d7690f473fe9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing pairwise scores (query gradient) [0/1] 0%| [time left: ?, time spent: 00:00]\n", + "Computing pairwise scores (training gradient) [1/1] 100%|███████████████████████████████████████████████████████████████████████████████████████████████ [time left: 00:00, time spent: 00:00]\u001b[A\n", + "Computing pairwise scores (query gradient) [1/1] 100%|██████████████████████████████████████████████████████████████████████████████████████████████████ [time left: 00:00, time spent: 00:00]\n", + "Computing pairwise scores (query gradient) [0/1] 0%| [time left: ?, time spent: 00:00]\n", + "Computing pairwise scores (training gradient) [1/1] 100%|███████████████████████████████████████████████████████████████████████████████████████████████ [time left: 00:00, time spent: 00:00]\u001b[A\n", + "Computing pairwise scores (query gradient) [1/1] 100%|██████████████████████████████████████████████████████████████████████████████████████████████████ [time left: 00:00, time spent: 00:00]\n" + ] + } + ], "source": [ "from kronfluence import FactorArguments\n", "\n", @@ -1317,6 +1330,14 @@ "plt.ylabel(\"Query Loss\")\n", "plt.xlabel(\"Number of Training Samples Removed\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0f7d37b-4d80-480c-ab0b-a9fa0aab2096", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {