Commit
Remove syncing from the data-parallel algo
lukstafi committed Nov 23, 2024
1 parent 9dd686b commit 5ca3cff
Showing 2 changed files with 10 additions and 17 deletions.
25 changes: 9 additions & 16 deletions lib/train.ml
@@ -314,14 +314,14 @@ let%track3_sexp sync_run ?looping (type buffer_ptr dev runner event)

module Lazy = Utils.Lazy

(** Performs one optimization step, potentially in parallel (if [grad_updates] are compiled for
different devices). All jitted code must have the same bindings. Iterates over bindings with
ranges, calling one of [grad_updates] in a round-robin fashion, and performs the following
synchronization each time all [grad_updates] have been called:
1. merges all gradients into the device of [grad_updates.(0)], 2. calls [sgd_update], 3. copies
all parameters from the [grad_updates.(0)] device to the other devices, if needed, 4. calls
[post_sync] with the number of devices synced since the previous sync.
(** Performs one optimization step, potentially in parallel (if [grad_updates] are linked with
different streams or devices). All jitted code must have the same bindings. Iterates over
bindings with ranges, calling one of [grad_updates] in a round-robin fashion, and performs the
following synchronization each time all [grad_updates] have been called:
- merges all gradients into the device of [grad_updates.(0)],
- calls [sgd_update],
- copies all parameters from the [grad_updates.(0)] device to the other devices, if needed,
- calls [post_sync] with the number of devices synced since the previous sync.
All and only bindings with associated ranges are iterated, with the binding's initial value
lost. Bindings without ranges remain at their initial values. *)
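
The doc comment above describes the per-step schedule; the following OCaml sketch (not part of the diff) restates that control flow. Every name in it, round_robin_steps, grad_update, merge_grads, sgd_update, copy_params_from_0, post_sync, is a hypothetical stand-in rather than the actual train.ml API, and binding-range iteration is simplified to a plain step counter.

(* Illustrative sketch (an assumption, not the actual [train.ml] code) of the
   round-robin schedule described in the doc comment above. *)
let round_robin_steps ~num_streams ~num_steps ~grad_update ~merge_grads ~sgd_update
    ~copy_params_from_0 ~post_sync =
  let pending = ref 0 in
  let sync () =
    if !pending > 0 then begin
      (* 1. Merge gradients from every stream that ran since the last sync into stream 0. *)
      for from = 1 to !pending - 1 do
        merge_grads ~from ~to_:0
      done;
      (* 2. Run the optimizer step where the merged gradients live. *)
      sgd_update ();
      (* 3. Broadcast the updated parameters back to the other streams. *)
      for to_ = 1 to num_streams - 1 do
        copy_params_from_0 ~to_
      done;
      (* 4. Report how many streams were synced since the previous sync. *)
      post_sync ~num_synced:(!pending);
      pending := 0
    end
  in
  for step = 0 to num_steps - 1 do
    (* Round-robin: step [i] runs on stream [i mod num_streams]. *)
    grad_update ~stream:(step mod num_streams) ~step;
    incr pending;
    if !pending = num_streams then sync ()
  done;
  (* Flush the final, possibly partial, round. *)
  sync ()

In the actual routine the merges and updates are scheduled on streams; as the todo.md entry below notes, the host-side awaits removed by this commit are subsumed by automatic stream-to-stream syncing.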
@@ -378,20 +378,16 @@ let%track3_sexp parallel_update (type buffer_ptr dev runner event)
let into_merge_buffer = if copy_to_merge then BT.Copy else BT.Streaming in
(* Since each device has its own queue, we can iterate over devices in the outer loop. *)
let merge_grads ~(from : int) ~(to_ : int) : unit =
(* Synchronize the source since we compute on the destination. *)
Backend.(await ctxs.(from).stream);
Array.iteri all_params ~f:(fun i p ->
let grad_merge =
Option.value_exn ~here:[%here] ~message:(Tn.debug_name p.value) grad_merges_to.(to_).(i)
in
(* NOTE: we no longer have to pass [grad_merge.context] as [dst]. *)
assert (
Backend.device_to_device (Option.value_exn ~here:[%here] p.diff).grad ~into_merge_buffer
~dst:ctxs.(to_) ~src:ctxs.(from));
(Task.run grad_merge.schedule : unit))
in
let merge_loss ~src =
(* NOTE: we no longer have to pass [loss_merge.context] as [dst]. *)
assert (
Backend.device_to_device updaten.loss.value ~into_merge_buffer ~dst:sgd_update.context ~src);
Task.run loss_merge.schedule
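
The deleted Backend.(await ctxs.(from).stream) in this hunk was a host-side barrier; per the todo.md entry below, stream-to-stream syncing is now automatic. The sketch that follows (not part of the diff) shows one way such syncing can be expressed as per-buffer events; the types and functions here are assumptions for exposition only, not the actual Backend interface.

(* Hypothetical illustration of per-buffer, event-based stream-to-stream syncing. *)
type event = { id : int }

type stream = {
  name : string;
  mutable last_write : (string * event) list; (* buffer name -> completion event *)
  mutable next_event : int;
}

let make_stream name = { name; last_write = []; next_event = 0 }

(* The producer stream records an event when it schedules a write to [buffer]. *)
let record_write s ~buffer =
  let e = { id = s.next_event } in
  s.next_event <- s.next_event + 1;
  s.last_write <- (buffer, e) :: List.remove_assoc buffer s.last_write;
  e

(* A cross-stream copy does not block the host: it only needs the event that the
   destination stream must wait for (device-side) before reading [buffer]. *)
let copy_dependency ~src ~buffer : event option = List.assoc_opt buffer src.last_write

let () =
  let s0 = make_stream "stream0" and s1 = make_stream "stream1" in
  ignore (record_write s1 ~buffer:"p.grad");
  match copy_dependency ~src:s1 ~buffer:"p.grad" with
  | Some e -> Printf.printf "%s waits on %s event #%d before copying\n" s0.name s1.name e.id
  | None -> print_endline "no pending write, the copy can start immediately"

The idea is that the destination stream waits only on the producer's relevant work rather than the host blocking on the whole source stream, which is what makes the deleted await redundant.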
@@ -400,12 +396,9 @@ let%track3_sexp parallel_update (type buffer_ptr dev runner event)
let needed_on_host = ref @@ Set.empty (module Tn) in
let%track3_sexp sync (devices_to_sync : int) : unit =
Arrayjit.Utils.parallel_merge merge_grads devices_to_sync;
(* We need to wait, because copying happens on other devices. *)
Array.iteri ctxs ~f:(fun i src -> if i <> 0 then Backend.(await src.stream));
Task.run sgd_update.schedule;
Array.iteri ctxs ~f:(fun i src -> if i <> 0 then merge_loss ~src);
Set.iter !needed_on_host ~f:(fun p -> assert (Backend.to_host sgd_update.context p));
Backend.(await sgd_update.context.stream);
(* We will need to update params on all devices! Not only the ones that computed gradients. *)
for to_ = 1 to num_devices - 1 do
Array.iter all_params ~f:(fun p ->
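
The sync function above delegates the gradient combination to Arrayjit.Utils.parallel_merge. As an illustration only (the actual implementation may differ), a recursive-halving tree reduction is one natural shape for such a merge:

(* Assumed shape of a tree reduction compatible with the [merge_grads ~from ~to_]
   callback above; not necessarily what [Arrayjit.Utils.parallel_merge] does. *)
let tree_merge (merge : from:int -> to_:int -> unit) (n : int) : unit =
  let rec loop active =
    if active > 1 then begin
      let half = (active + 1) / 2 in
      (* Each round merges the upper half of the active range into the lower half;
         the merges within a round are independent, so O(log n) dependent rounds suffice. *)
      for to_ = 0 to active - half - 1 do
        merge ~from:(to_ + half) ~to_
      done;
      loop half
    end
  in
  loop n

Called as tree_merge merge_grads devices_to_sync, it would fold every participating stream's gradients into ctxs.(0), matching the first synchronization step in the doc comment.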
@@ -517,7 +510,6 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
assert (Backend.to_host sgd_update.context learning_rate.value);
(* scalar_loss is not in the sgd_update context. *)
assert (Backend.to_host grad_updates.(0).context scalar_loss.value);
Backend.(await grad_updates.(0).context.stream);
let batch_loss = scalar_loss.@[0] in
epoch_loss := !epoch_loss +. batch_loss;
batch_losses := batch_loss :: !batch_losses;
@@ -561,6 +553,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
assert (Backend.from_host routine.context infer.value);
run routine;
assert (Backend.to_host routine.context model_result.value);
(* TODO: get_values itself should sync with host writing events. *)
Backend.(await routine.context.stream);
Tensor.get_values model_result
in
2 changes: 1 addition & 1 deletion todo.md
@@ -1,5 +1,5 @@
# This file is for tasks with a smaller granularity than issues, typically immediate tasks.
(B) bin/moons_benchmark with the cc backend crashes with half-prec overflow
(B) remove syncing from the data parallel algo: stream-to-stream syncing is now automatic
(B) remove syncing from the data parallel algo: stream-to-stream syncing is now automatic {cm:2024-11-23}
(A) cuda backend crashes in bin/moons_benchmark {cm:2024-11-22}
(B) figure out why cuda backend parallelism slows down in later epochs
