Skip to content

Commit

Permalink
correct number of epochs in example
Browse files Browse the repository at this point in the history
  • Loading branch information
KGrewal1 committed Nov 10, 2023
1 parent da6a36b commit ebcbd32
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions examples/mnist/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub fn training_loop<M: Model, O: Optim>(
// load the test labels
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
// loop for model optimisation
for epoch in 1..args.epochs {
for epoch in 0..args.epochs {
// get log probabilities of results
let logits = model.forward(&train_images)?;
// softmax the log probabilities
Expand All @@ -60,7 +60,8 @@ pub fn training_loop<M: Model, O: Optim>(
// get the accuracy on the test set
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
println!(
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
"{:4} train loss: {:8.5} test acc: {:5.2}%",
epoch + 1,
loss.to_scalar::<f32>()?,
100. * test_accuracy
);
Expand Down

0 comments on commit ebcbd32

Please sign in to comment.