Fix naming in Train.example_train_result
lukstafi committed Dec 14, 2024
1 parent a0a7f9c commit 7d3eeba
Showing 5 changed files with 67 additions and 2,258 deletions.
10 changes: 5 additions & 5 deletions bin/moons_benchmark.ml
@@ -104,8 +104,8 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
         outputs;
         model_result;
         infer_callback;
-        batch_losses;
-        epoch_losses;
+        rev_batch_losses;
+        rev_epoch_losses;
         learning_rates;
         used_memory;
       } =
@@ -148,7 +148,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
             Line_plot
               {
                 points =
-                  Array.of_list_rev_map batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x));
+                  Array.of_list_rev_map rev_batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x));
                 pixel = "-";
               };
           ]
@@ -158,7 +158,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   let plot_loss =
     let open PrintBox_utils in
     plot ~size:(120, 30) ~x_label:"step" ~y_label:"epoch log loss"
-      [ Line_plot { points = Array.of_list_rev_map epoch_losses ~f:Float.log; pixel = "-" } ]
+      [ Line_plot { points = Array.of_list_rev_map rev_epoch_losses ~f:Float.log; pixel = "-" } ]
   in
   PrintBox_text.output Stdio.stdout plot_loss;
   Stdio.printf "\nLearning rate:\n%!";
@@ -186,7 +186,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
       result_label = "init time in sec, min loss, last loss";
       result =
         [%sexp_of: float * float * float]
-          (init_time_in_sec, List.reduce_exn epoch_losses ~f:Float.min, List.hd_exn epoch_losses);
+          (init_time_in_sec, List.reduce_exn rev_epoch_losses ~f:Float.min, List.hd_exn rev_epoch_losses);
     }
   in
   Stdio.printf "\n\n%!";
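Both plotting hunks rely on Base's `Array.of_list_rev_map`, which reverses the list while mapping, so the newest-first `rev_batch_losses` comes out oldest-first, as a left-to-right line plot expects. A minimal sketch of the same clamped-log transform, with made-up loss values and only `Base` assumed:

```ocaml
open Base

let rev_batch_losses = [ 0.01; 0.05; 0.2; 0.9 ] (* newest loss first *)

(* Reverse into chronological order and take logs in one pass, clamping
   below at log 0.00003 so near-zero losses don't stretch the y-axis. *)
let log_points =
  Array.of_list_rev_map rev_batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x))
(* log_points = [| log 0.9; log 0.2; log 0.05; log 0.01 |] *)
```

The same newest-first convention explains the result hunk above: `List.hd_exn rev_epoch_losses` is the most recent epoch's loss, which is exactly what the "last loss" label promises, with no reversal needed.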
12 changes: 6 additions & 6 deletions bin/moons_demo_parallel.ml
@@ -58,8 +58,8 @@ let experiment ~seed ~backend_name ~config () =
         outputs;
         model_result;
         infer_callback;
-        batch_losses;
-        epoch_losses;
+        rev_batch_losses;
+        rev_epoch_losses;
         learning_rates;
         used_memory;
       } =
@@ -91,14 +91,14 @@ let experiment ~seed ~backend_name ~config () =
   let plot_loss =
     let open PrintBox_utils in
     plot ~size:(120, 30) ~x_label:"step" ~y_label:"batch loss"
-      [ Line_plot { points = Array.of_list_rev batch_losses; pixel = "-" } ]
+      [ Line_plot { points = Array.of_list_rev rev_batch_losses; pixel = "-" } ]
   in
   PrintBox_text.output Stdio.stdout plot_loss;
   Stdio.printf "\nEpoch Loss:\n%!";
   let plot_loss =
     let open PrintBox_utils in
     plot ~size:(120, 30) ~x_label:"step" ~y_label:"epoch loss"
-      [ Line_plot { points = Array.of_list_rev epoch_losses; pixel = "-" } ]
+      [ Line_plot { points = Array.of_list_rev rev_epoch_losses; pixel = "-" } ]
   in
   PrintBox_text.output Stdio.stdout plot_loss;
   Stdio.printf "\nBatch Log-loss:\n%!";
@@ -109,7 +109,7 @@ let experiment ~seed ~backend_name ~config () =
           Line_plot
             {
               points =
-                Array.of_list_rev_map batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x));
+                Array.of_list_rev_map rev_batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x));
               pixel = "-";
             };
         ]
@@ -119,7 +119,7 @@ let experiment ~seed ~backend_name ~config () =
   let plot_loss =
     let open PrintBox_utils in
     plot ~size:(120, 30) ~x_label:"step" ~y_label:"epoch log loss"
-      [ Line_plot { points = Array.of_list_rev_map epoch_losses ~f:Float.log; pixel = "-" } ]
+      [ Line_plot { points = Array.of_list_rev_map rev_epoch_losses ~f:Float.log; pixel = "-" } ]
   in
   PrintBox_text.output Stdio.stdout plot_loss;
   Stdio.printf "\nLearning rate:\n%!";
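The demo plots raw losses as well as log losses, so it uses both `Array.of_list_rev` (plain reversal into an array) and `Array.of_list_rev_map` (reversal fused with a map). A short sketch with made-up values, again assuming only `Base`:

```ocaml
open Base

let rev_epoch_losses = [ 0.12; 0.45; 1.3 ] (* newest first *)

(* Chronological array, values untouched — feeds the linear-scale plot. *)
let raw_points = Array.of_list_rev rev_epoch_losses
(* raw_points = [| 1.3; 0.45; 0.12 |] *)

(* Chronological array with the log applied along the way. *)
let log_points = Array.of_list_rev_map rev_epoch_losses ~f:Float.log
```

Either way, the reversal happens once at the plot boundary, so the training loop itself never pays for it.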
16 changes: 8 additions & 8 deletions lib/train.ml
@@ -450,8 +450,8 @@ type example_train_result = {
   model_result : Tensor.t;
   infer_callback : float array -> float array;
       (** Note: infer_callback is significantly less efficient than using the model via arrayjit. *)
-  batch_losses : float list;
-  epoch_losses : float list;
+  rev_batch_losses : float list;
+  rev_epoch_losses : float list;
   learning_rates : float list;
   used_memory : int;
 }
@@ -484,8 +484,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
   let step_n, bindings = IDX.get_static_symbol bindings in
   let%op input = inputs @| batch_n in
   let%op expectation = outputs @| batch_n in
-  let batch_losses = ref [] in
-  let epoch_losses = ref [] in
+  let rev_batch_losses = ref [] in
+  let rev_epoch_losses = ref [] in
   let learning_rates = ref [] in
   let%op loss_tensor = loss_fn ~output:(model input) ~expectation in
   let%op scalar_loss = (loss_tensor ++ "...|... => 0") /. !..batch_size in
@@ -524,7 +524,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
           Backend.await grad_updates.(0).context.stream;
           let batch_loss = scalar_loss.@[0] in
           epoch_loss := !epoch_loss +. batch_loss;
-          batch_losses := batch_loss :: !batch_losses;
+          rev_batch_losses := batch_loss :: !rev_batch_losses;
           Option.iter per_batch_callback ~f:(fun f ->
               f ~at_batch:!batch_ref ~at_step:!step_ref ~learning_rate:learning_rate.@[0] ~batch_loss
                 ~epoch_loss:!epoch_loss))
@@ -538,7 +538,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
       (* Utils.capture_stdout_logs *)
       update ();
       learning_rates := learning_rate.@[0] :: !learning_rates;
-      epoch_losses := !epoch_loss :: !epoch_losses;
+      rev_epoch_losses := !epoch_loss :: !rev_epoch_losses;
       Option.iter per_epoch_callback ~f:(fun f ->
           f ~at_step:!step_ref ~at_epoch:epoch ~learning_rate:learning_rate.@[0]
             ~epoch_loss:!epoch_loss);
@@ -590,8 +590,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
     outputs;
     model_result;
     infer_callback;
-    batch_losses = !batch_losses;
-    epoch_losses = !epoch_losses;
+    rev_batch_losses = !rev_batch_losses;
+    rev_epoch_losses = !rev_epoch_losses;
     learning_rates = !learning_rates;
     used_memory;
   }
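Because OCaml record patterns bind by field name, the rename forces every destructuring call site (as in the two `bin/` files above) to change in the same commit. A self-contained miniature of the record and a consumer, with the field set abbreviated relative to the real `example_train_result`:

```ocaml
open Base

(* Abbreviated stand-in for Train.example_train_result. *)
type mini_result = {
  rev_batch_losses : float list;
  rev_epoch_losses : float list;
}

(* hd_exn on a newest-first list is the most recent value, and
   reduce_exn ~f:Float.min is order-insensitive, so neither needs List.rev. *)
let summarize { rev_batch_losses = _; rev_epoch_losses } =
  (List.reduce_exn rev_epoch_losses ~f:Float.min, List.hd_exn rev_epoch_losses)
```

Renaming the fields rather than reversing the lists at the source keeps the training loop cheap and makes the newest-first ordering explicit in the type.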