diff --git a/lib/train.ml b/lib/train.ml index f6336cea..b27f7ab2 100644 --- a/lib/train.ml +++ b/lib/train.ml @@ -314,14 +314,14 @@ let%track3_sexp sync_run ?looping (type buffer_ptr dev runner event) module Lazy = Utils.Lazy -(** Performs one optimization step, potentially in parallel (if [grad_updates] are compiled for - different devices). All jitted code must have the same bindings. Iterates over bindings with - ranges, calling one of [grad_updates] in a round-robin fashion, and performs the following - synchronization each time all [grad_updates] have been called: - - 1. merges all gradients into the device of [grad_updates.(0)], 2. calls [sgd_update], 3. copies - all parameters from the [grad_updates.(0)] device to the other devices, if needed, 4. calls - [post_sync] with the number of devices synced since the previous sync. +(** Performs one optimization step, potentially in parallel (if [grad_updates] are linked with + different streams or devices). All jitted code must have the same bindings. Iterates over + bindings with ranges, calling one of [grad_updates] in a round-robin fashion, and performs the + following synchronization each time all [grad_updates] have been called: + - merges all gradients into the device of [grad_updates.(0)], + - calls [sgd_update], + - copies all parameters from the [grad_updates.(0)] device to the other devices, if needed, + - calls [post_sync] with the number of devices synced since the previous sync. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values. *) @@ -378,20 +378,16 @@ let%track3_sexp parallel_update (type buffer_ptr dev runner event) let into_merge_buffer = if copy_to_merge then BT.Copy else BT.Streaming in (* Since each device has its own queue, we can iterate over devices in the outer loop. *) let merge_grads ~(from : int) ~(to_ : int) : unit = - (* Synchronize the source since we compute on the destionation. *) - Backend.(await ctxs.(from).stream); Array.iteri all_params ~f:(fun i p -> let grad_merge = Option.value_exn ~here:[%here] ~message:(Tn.debug_name p.value) grad_merges_to.(to_).(i) in - (* NOTE: we no longer have to to pass [grad_merge.context] as [dst]. *) assert ( Backend.device_to_device (Option.value_exn ~here:[%here] p.diff).grad ~into_merge_buffer ~dst:ctxs.(to_) ~src:ctxs.(from)); (Task.run grad_merge.schedule : unit)) in let merge_loss ~src = - (* NOTE: we no longer have to to pass [loss_merge.context] as [dst]. *) assert ( Backend.device_to_device updaten.loss.value ~into_merge_buffer ~dst:sgd_update.context ~src); Task.run loss_merge.schedule @@ -400,12 +396,9 @@ let%track3_sexp parallel_update (type buffer_ptr dev runner event) let needed_on_host = ref @@ Set.empty (module Tn) in let%track3_sexp sync (devices_to_sync : int) : unit = Arrayjit.Utils.parallel_merge merge_grads devices_to_sync; - (* We need to wait, because copying happens on other devices. *) - Array.iteri ctxs ~f:(fun i src -> if i <> 0 then Backend.(await src.stream)); Task.run sgd_update.schedule; Array.iteri ctxs ~f:(fun i src -> if i <> 0 then merge_loss ~src); Set.iter !needed_on_host ~f:(fun p -> assert (Backend.to_host sgd_update.context p)); - Backend.(await sgd_update.context.stream); (* We will need to update params on all devices! Not only the ones that computed gradients. *) for to_ = 1 to num_devices - 1 do Array.iter all_params ~f:(fun p -> @@ -517,7 +510,6 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init assert (Backend.to_host sgd_update.context learning_rate.value); (* scalar_loss is not in the sgd_update context. *) assert (Backend.to_host grad_updates.(0).context scalar_loss.value); - Backend.(await grad_updates.(0).context.stream); let batch_loss = scalar_loss.@[0] in epoch_loss := !epoch_loss +. batch_loss; batch_losses := batch_loss :: !batch_losses; @@ -561,6 +553,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init assert (Backend.from_host routine.context infer.value); run routine; assert (Backend.to_host routine.context model_result.value); + (* TODO: get_values itself should sync with host writing events. *) Backend.(await routine.context.stream); Tensor.get_values model_result in diff --git a/todo.md b/todo.md index 20b3d7a8..511e9ad8 100644 --- a/todo.md +++ b/todo.md @@ -1,5 +1,5 @@ # This file is for tasks with a smaller granularity than issues, typically immediate tasks. (B) bin/moons_benchmark with the cc backend crashes with half-prec overflow -(B) remove syncing from the data parallel algo: stream-to-stream syncing is now automatic +(B) remove syncing from the data parallel algo: stream-to-stream syncing is now automatic {cm:2024-11-23} (A) cuda backend crashes in bin/moons_benchmark {cm:2024-11-22} (B) figure out why cuda backend parallelism slows down in later epochs \ No newline at end of file