Skip to content

Commit

Permalink
fix: pointnet example (#449)
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 authored Nov 3, 2023
1 parent 3f52406 commit 0ea7668
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,9 @@
" range(num_train_examples),\n",
" desc=f\"Training Epoch {epoch}/{config.epochs}\"\n",
" )\n",
" data_iter = iter(train_loader)\n",
" for batch_idx in progress_bar:\n",
" data = next(iter(train_loader)).to(device)\n",
" data = next(data_iter).to(device)\n",
" \n",
" optimizer.zero_grad()\n",
" prediction = model(data)\n",
Expand Down Expand Up @@ -395,8 +396,9 @@
" range(num_val_examples),\n",
" desc=f\"Validation Epoch {epoch}/{config.epochs}\"\n",
" )\n",
" data_iter = iter(val_loader)\n",
" for batch_idx in progress_bar:\n",
" data = next(iter(val_loader)).to(device)\n",
" data = next(data_iter).to(device)\n",
" \n",
" with torch.no_grad():\n",
" prediction = model(data)\n",
Expand Down
6 changes: 4 additions & 2 deletions colabs/pyg/pointnet-classification/03_sweep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,9 @@
" range(num_train_examples),\n",
" desc=f\"Training Epoch {epoch}/{config.epochs}\"\n",
" )\n",
" data_iter = iter(train_loader)\n",
" for batch_idx in progress_bar:\n",
" data = next(iter(train_loader)).to(device)\n",
" data = next(data_iter).to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" prediction = model(data)\n",
Expand Down Expand Up @@ -315,8 +316,9 @@
" range(num_val_examples),\n",
" desc=f\"Validation Epoch {epoch}/{config.epochs}\"\n",
" )\n",
" data_iter = iter(val_loader)\n",
" for batch_idx in progress_bar:\n",
" data = next(iter(val_loader)).to(device)\n",
" data = next(data_iter).to(device)\n",
"\n",
" with torch.no_grad():\n",
" prediction = model(data)\n",
Expand Down

0 comments on commit 0ea7668

Please sign in to comment.