diff --git a/CHANGES.md b/CHANGES.md
index f88078ec..55c5c94e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -4,7 +4,8 @@
 - Interface files for `Backends` and `Low_level`.
 - Fixed #245: tracking of used memory.
 
-- TODO: stream-to-stream synchronization functionality, with lazy per-tensor-node synchronization.
+- Stream-to-stream synchronization functionality, with lazy per-tensor-node synchronization.
+- TODO: Automatic blocking on access to a host array when a scheduled `to_host` transfer has not finished.
 
 ### Changed
 
@@ -19,7 +20,8 @@
 - Got rid of `subordinal`.
 - Removed dependency on `core`, broke up dependency on `ppx_jane`.
 - Huge refactoring of backend internal interfaces and API (not repeating same code).
-- TODO: Built per-tensor-node stream-to-stream synchronization into copying functions, removed obsolete blocking synchronizations.
+- Built per-tensor-node stream-to-stream synchronization into copying functions.
+- Re-introduced whole-device blocking synchronization; it is now just a slight optimization, and it also cleans up event book-keeping.
 
 ### Fixed
 
diff --git a/arrayjit/lib/backend_impl.ml b/arrayjit/lib/backend_impl.ml
index 59f41470..4276ace4 100644
--- a/arrayjit/lib/backend_impl.ml
+++ b/arrayjit/lib/backend_impl.ml
@@ -101,26 +101,27 @@ struct
     {
       dev;
       ordinal;
-      latest_stream_id = -1;
       released = Atomic.make false;
       cross_stream_candidates = Hashtbl.create (module Tnode);
       owner_stream = Hashtbl.create (module Tnode);
       shared_writer_streams = Hashtbl.create (module Tnode);
       host_reading_streams = Hashtbl.create (module Tnode);
       host_writing_streams = Hashtbl.create (module Tnode);
+      streams = Utils.weak_create ();
     }
 
-  let make_stream device runner ~stream_id =
-    {
-      device;
-      runner;
-      merge_buffer = ref None;
-      stream_id;
-      allocated_buffer = None;
-      updating_for = Hashtbl.create (module Tnode);
-      updating_for_merge_buffer = None;
-      reader_streams = Hashtbl.create (module Tnode);
-    }
+  let make_stream device runner =
+    Utils.register_new device.streams ~grow_by:8 (fun stream_id ->
+        {
+          device;
+          runner;
+          merge_buffer = ref None;
+          stream_id;
+          allocated_buffer = None;
+          updating_for = Hashtbl.create (module Tnode);
+          updating_for_merge_buffer = None;
+          reader_streams = Hashtbl.create (module Tnode);
+        })
 
   let get_name stream = [%string "%{name}:%{stream.device.ordinal#Int}:%{stream.stream_id#Int}"]
 
diff --git a/arrayjit/lib/backend_intf.ml b/arrayjit/lib/backend_intf.ml
index 1a296c08..b7f9834e 100644
--- a/arrayjit/lib/backend_intf.ml
+++ b/arrayjit/lib/backend_intf.ml
@@ -82,7 +82,6 @@ end
 type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
   dev : 'dev;
   ordinal : int;
-  mutable latest_stream_id : int;
   released : Utils.atomic_bool;
   cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
   owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
@@ -92,6 +91,7 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
     (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
   host_writing_streams :
     (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
+  mutable streams : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Utils.weak_dynarray;
 }
 
 and ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
@@ -114,7 +114,6 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
     ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
   dev : 'dev;
   ordinal : int;
-  mutable latest_stream_id : int;
   released : Utils.atomic_bool;
   cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;  (** Freshly created arrays
          that might be shared across streams. The map can both grow and
@@ -136,6 +135,7 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
     (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
       (** The streams that most recently have been writing to a node's on-host array. The
           completed events are removed opportunistically. *)
+  mutable streams : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Utils.weak_dynarray;  (** The streams created on this device; stale streams can be garbage collected. *)
 }
 [@@deriving sexp_of]
 
@@ -147,7 +147,7 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream =
       (** Depending on backend implementations, either the currently used merge buffer, or the one
           most recently scheduled. Note that the pointer can be reused for nodes that fit in an
           already allocated buffer. *)
-  stream_id : int;  (** An ID unique within the device. *)
+  stream_id : int;  (** An ID unique within the device for the lifetime of the stream. *)
   mutable allocated_buffer : 'buffer_ptr buffer option;
   updating_for : 'event Hashtbl.M(Tnode).t;
   (* The completion event for the most recent updating (writing to) a node via this stream. *)
@@ -188,7 +188,7 @@ module type Device = sig
   include Alloc_buffer with type buffer_ptr := buffer_ptr and type stream := stream
 
   val make_device : dev -> ordinal:int -> device
-  val make_stream : device -> runner -> stream_id:int -> stream
+  val make_stream : device -> runner -> stream
 
   val make_context : ?ctx_arrays:ctx_arrays -> stream -> context
   (** Returns a context without a parent. *)
@@ -291,6 +291,7 @@ module type Backend_device_common = sig
 end
 
 module type With_buffer_retrieval_and_syncing = sig
+  type device
   type context
   type event
 
@@ -318,6 +319,9 @@ module type With_buffer_retrieval_and_syncing = sig
         buffer, and initializes the merge buffer's streaming event.
       - If [into_merge_buffer=Copy], schedules copying from [src] to the merge buffer of [dst]'s
         stream, and updates the writer event for the merge buffer. *)
+
+  val sync_device : device -> unit
+  (** Synchronizes all the streams on a device, and cleans up (removes) all associated events. *)
 end
 
 module type Backend = sig
@@ -331,5 +335,9 @@ module type Backend = sig
   (** Returns the routines for the procedures included in the code batch. The returned context is
       downstream of all the returned routines. *)
 
-  include With_buffer_retrieval_and_syncing with type context := context and type event := event
+  include
+    With_buffer_retrieval_and_syncing
+      with type device := device
+      and type context := context
+      and type event := event
 end
diff --git a/arrayjit/lib/backends.ml b/arrayjit/lib/backends.ml
index e2b3a4ee..3ecf1443 100644
--- a/arrayjit/lib/backends.ml
+++ b/arrayjit/lib/backends.ml
@@ -22,7 +22,7 @@ let check_merge_buffer stream ~code_node =
         ^ ", expected by code: " ^ name code_node)
 
 module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing) = struct
-  let wait_for_all ctx streams tn =
+  let[@landmark] wait_for_all ctx streams tn =
     let s = ctx.stream in
     Hashtbl.update_and_return streams tn
       ~f:
@@ -31,7 +31,7 @@ module Add_buffer_retrieval_and_syncin
       |> List.iter ~f:(fun (work_stream, e) ->
             if not (equal_stream work_stream s) then Backend.will_wait_for ctx e)
 
-  let wait_for_ready ~dst ~src tn =
+  let[@landmark] wait_for_ready ~dst ~src tn =
     let s = src.stream in
     let d = dst.stream in
     (* TODO: maybe it's worthwhile to clean up s.updating_for every now and then.
*) @@ -39,7 +39,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin |> Option.iter ~f:(fun upd_e -> if not (equal_stream s d || Backend.is_done upd_e) then Backend.will_wait_for dst upd_e) - let update_writer_event ?e ?from s tn = + let[@landmark] update_writer_event ?e ?from s tn = let e = Option.value_or_thunk e ~default:(fun () -> Backend.all_work s) in let f l = (s, e) :: Option.value ~default:[] l in (match (from, tn) with @@ -52,13 +52,14 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin | Node tn -> if Tn.potentially_cross_stream tn then Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l -> - (s, e) :: Option.value ~default:[] l); + (s, e) :: Option.value ~default:[] l) + else Hashtbl.remove s.device.shared_writer_streams tn; Hashtbl.update s.updating_for tn ~f:(fun _ -> e) | Merge_buffer tn -> (* Note: the previous event does not need to be done! *) s.updating_for_merge_buffer <- Some (tn, Some e) - let%diagn2_l_sexp from_host (ctx : Backend.context) tn = + let%track2_l_sexp[@landmark] from_host (ctx : Backend.context) tn = match (tn, Map.find ctx.ctx_arrays tn) with | { Tn.array = (lazy (Some hosted)); _ }, Some dst -> wait_for_all ctx ctx.stream.reader_streams tn; @@ -68,7 +69,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin true | _ -> false - let%diagn2_l_sexp to_host (ctx : Backend.context) (tn : Tn.t) = + let%track2_l_sexp[@landmark] to_host (ctx : Backend.context) (tn : Tn.t) = match (tn, Map.find ctx.ctx_arrays tn) with | { Tn.array = (lazy (Some hosted)); _ }, Some src -> if Tn.potentially_cross_stream tn then @@ -82,8 +83,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin true | _ -> false - let%diagn2_l_sexp device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : Backend.context) - ~(src : Backend.context) = + let%diagn2_l_sexp[@landmark] device_to_device (tn : Tn.t) ~into_merge_buffer + ~(dst : Backend.context) ~(src : Backend.context) = let ordinal_of ctx = ctx.stream.device.ordinal in let name_of ctx = Backend.(get_name ctx.stream) in let same_device = ordinal_of dst = ordinal_of src in @@ -115,30 +116,40 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin Backend.( device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src); dst.stream.updating_for_merge_buffer <- Some (tn, None); - Task.run task; + let[@landmark] merge_task () = Task.run task in + merge_task (); update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn; [%log "streaming into merge buffer", Tn.debug_name tn, "from", name_of src]; true) - let%track3_l_sexp sync_routine r = + let%track2_l_sexp sync_routine r = let s = r.context.stream in - let pre () = - Hashtbl.filter_mapi_inplace s.device.shared_writer_streams ~f:(fun ~key ~data -> - if Tn.potentially_cross_stream key then - if Set.mem r.inputs key then ( - let data = List.filter data ~f:(fun (_, e) -> Backend.is_done e) in - List.iter data ~f:(fun (work_stream, e) -> - if not (equal_stream work_stream s) then Backend.will_wait_for r.context e); - Some data) - else Some data - else None) + let[@landmark] pre () = + Set.iter r.inputs ~f:(fun tn -> + if Tn.potentially_cross_stream tn then + Option.iter (Hashtbl.find s.device.shared_writer_streams tn) ~f:(fun data -> + let data = List.filter data ~f:(fun (_, e) -> not (Backend.is_done e)) in + Hashtbl.set s.device.shared_writer_streams ~key:tn ~data; + List.iter data ~f:(fun (work_stream, e) 
-> + if not (equal_stream work_stream s) then Backend.will_wait_for r.context e)) + else Hashtbl.remove s.device.shared_writer_streams tn) (* Since merge buffers are always per-stream, no need to check r.merge_buffer_input. *) in - let post () = + let[@landmark] post () = let e = Backend.all_work s in Set.iter r.outputs ~f:(fun tn -> update_writer_event ~e s @@ Node tn) in { r with schedule = Task.(prepend ~work:pre @@ append ~work:post r.schedule) } + + let[@landmark] sync_device device = + Utils.weak_iter device.streams ~f:Backend.await; + Hashtbl.clear device.host_writing_streams; + Hashtbl.clear device.host_reading_streams; + Hashtbl.clear device.shared_writer_streams; + Utils.weak_iter device.streams ~f:(fun s -> + Hashtbl.clear s.reader_streams; + s.updating_for_merge_buffer <- None; + Hashtbl.clear s.updating_for) end let lower_assignments ?name bindings asgns = @@ -268,20 +279,20 @@ module Add_device in (Option.value_exn ~here:[%here] bindings, schedules) - let from_host ~dst_ptr ~dst hosted = + let[@landmark] from_host ~dst_ptr ~dst hosted = let work () = host_to_buffer hosted ~dst:dst_ptr in (* TODO: pass description to from_host. *) schedule_task dst.stream (Task.Task { context_lifetime = dst; description = "from_host on " ^ get_name dst.stream; work }) - let to_host ~src_ptr ~src hosted = + let[@landmark] to_host ~src_ptr ~src hosted = let work () = buffer_to_host hosted ~src:src_ptr in (* TODO: pass description to to_host. *) schedule_task src.stream (Task.Task { context_lifetime = src; description = "to_host on " ^ get_name src.stream; work }) - let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src = + let[@landmark] device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src = let s = dst.stream in let size_in_bytes = Tnode.size_in_bytes tn in let work = @@ -468,7 +479,7 @@ let reinitialize (module Backend : Backend) config = Stdlib.Gc.full_major (); Backend.initialize config) -let%track3_sexp finalize (type buffer_ptr dev runner event) +let[@landmark] finalize (type buffer_ptr dev runner event) (module Backend : Backend with type buffer_ptr = buffer_ptr and type dev = dev diff --git a/arrayjit/lib/cuda_backend.cudajit.ml b/arrayjit/lib/cuda_backend.cudajit.ml index dd1c174d..f5a81fab 100644 --- a/arrayjit/lib/cuda_backend.cudajit.ml +++ b/arrayjit/lib/cuda_backend.cudajit.ml @@ -13,11 +13,11 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime let () = Cu.cuda_call_hook := Some - (fun ~message ~status -> + (fun ~message:_message ~status:_status -> [%debug_l_sexp [%log5_block - message; - if not @@ Cu.is_success status then [%log (status : Cu.result)]]]) + _message; + if not @@ Cu.is_success _status then [%log (_status : Cu.result)]]]) let _suspended () = Cu.cuda_call_hook := Some (fun ~message ~status:_ -> Stdlib.Printf.printf "CUDA %s\n" message) @@ -149,11 +149,10 @@ let%track3_sexp get_device ~(ordinal : int) : device = if Atomic.get result.released then default () else result let%track3_sexp new_stream (device : device) : stream = - device.latest_stream_id <- device.latest_stream_id + 1; (* Strange that we need ctx_set_current even with a single device! 
*) set_ctx device.dev.primary_context; let cu_stream = Cu.Stream.create ~non_blocking:true () in - make_stream device cu_stream ~stream_id:device.latest_stream_id + make_stream device cu_stream let cuda_properties = let cache = @@ -173,24 +172,24 @@ let suggested_num_streams device = | For_parallel_copying -> 1 + (cuda_properties device).async_engine_count | Most_parallel_streams -> (cuda_properties device).multiprocessor_count -let await stream : unit = +let[@landmark] await stream : unit = set_ctx stream.device.dev.primary_context; Cu.Stream.synchronize stream.runner; Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ()) let is_idle stream = Cu.Stream.is_ready stream.runner -let from_host ~dst_ptr ~dst hosted = +let[@landmark] from_host ~dst_ptr ~dst hosted = set_ctx @@ ctx_of dst; let f src = Cu.Stream.memcpy_H_to_D ~dst:dst_ptr ~src dst.stream.runner in Ndarray.map { f } hosted -let to_host ~src_ptr ~src hosted = +let[@landmark] to_host ~src_ptr ~src hosted = set_ctx @@ ctx_of src; let f dst = Cu.Stream.memcpy_D_to_H ~dst ~src:src_ptr src.stream.runner in Ndarray.map { f } hosted -let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src = +let[@landmark] device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src = let dev = dst.stream.device in let same_device = dev.ordinal = src.stream.device.ordinal in let size_in_bytes = Tn.size_in_bytes tn in @@ -248,6 +247,8 @@ let%diagn2_sexp cuda_to_ptx ~name cu_src = let options = "--use_fast_math" :: (if Utils.with_runtime_debug () then [ "--device-debug" ] else []) in + (* FIXME: every now and then the compilation crashes because the options are garbled. *) + (* Stdio.printf "PTX options %s\n%!" @@ String.concat ~sep:", " options; *) let ptx = Cu.Nvrtc.compile_to_ptx ~cu_src ~name:name_cu ~options ~with_debug in if Utils.settings.output_debug_files_in_build_directory then ( let oc = Out_channel.open_text @@ Utils.build_file @@ name ^ ".ptx" in diff --git a/arrayjit/lib/dune b/arrayjit/lib/dune index 1e1b9018..1646ca98 100644 --- a/arrayjit/lib/dune +++ b/arrayjit/lib/dune @@ -12,6 +12,7 @@ ctypes ctypes.foreign saturn_lockfree + landmarks (select gcc_backend.ml from @@ -31,6 +32,7 @@ ppx_sexp_conv ppx_string ppx_variants_conv + landmarks-ppx ppx_minidebug)) (modules utils diff --git a/arrayjit/lib/schedulers.ml b/arrayjit/lib/schedulers.ml index d4e7c128..b38888fb 100644 --- a/arrayjit/lib/schedulers.ml +++ b/arrayjit/lib/schedulers.ml @@ -137,48 +137,62 @@ module Multicore (Backend : For_add_scheduler) : schedule_task stream @@ Task { context_lifetime = (); description = "clock tick"; work }; { stream_state; target_clock = stream_state.schedule_clock } - let%track3_l_sexp spinup_stream ~stream_id : stream = - Int.incr global_run_no; - let state = - { - keep_spinning = true; - stream_error = None; - queue = Queue.create ~size_exponent:12; - mut = Mut.create (); - is_ready = false; - host_wait_for_idle = Stdlib.Condition.create (); - dev_wait_for_work = Stdlib.Condition.create (); - clock_tick = Stdlib.Condition.create (); - schedule_clock = 0; - run_clock = Atomic.make 0; - stream_id; - } - in - let%track3_l_sexp worker (() : unit) : unit = - assert (not @@ Domain.is_main_domain ()); - try - while state.keep_spinning do - match Queue.pop_opt state.queue with - | None -> - Mut.lock state.mut; - state.is_ready <- true; - Stdlib.Condition.broadcast state.host_wait_for_idle; - while is_dev_queue_empty state && state.keep_spinning do - Stdlib.Condition.wait state.dev_wait_for_work state.mut - done; - 
state.is_ready <- false; - Mut.unlock state.mut - | Some task -> Task.run task - done - with e -> - state.stream_error <- Some e; - state.keep_spinning <- false; - [%log1 "stream", (stream_id : int), "exception", Exn.to_string e]; - (* TODO: we risk raising this error multiple times because await and schedule_task raise - stream_error. But this is fine if we assume all exceptions are fatal. *) - raise e + let%track3_l_sexp spinup_stream () : stream = + let create stream_id = + Int.incr global_run_no; + let state = + { + keep_spinning = true; + stream_error = None; + queue = Queue.create ~size_exponent:12; + mut = Mut.create (); + is_ready = false; + host_wait_for_idle = Stdlib.Condition.create (); + dev_wait_for_work = Stdlib.Condition.create (); + clock_tick = Stdlib.Condition.create (); + schedule_clock = 0; + run_clock = Atomic.make 0; + stream_id; + } + in + let%track3_l_sexp worker (() : unit) : unit = + assert (not @@ Domain.is_main_domain ()); + try + while state.keep_spinning do + match Queue.pop_opt state.queue with + | None -> + Mut.lock state.mut; + state.is_ready <- true; + Stdlib.Condition.broadcast state.host_wait_for_idle; + while is_dev_queue_empty state && state.keep_spinning do + Stdlib.Condition.wait state.dev_wait_for_work state.mut + done; + state.is_ready <- false; + Mut.unlock state.mut + | Some task -> Task.run task + done + with e -> + state.stream_error <- Some e; + state.keep_spinning <- false; + [%log1 "stream", (stream_id : int), "exception", Exn.to_string e]; + (* TODO: we risk raising this error multiple times because await and schedule_task raise + stream_error. But this is fine if we assume all exceptions are fatal. *) + raise e + in + { state; domain = Domain.spawn worker } in - make_stream device { state; domain = Domain.spawn worker } ~stream_id + Utils.register_new device.streams ~grow_by:8 (fun stream_id -> + let runner = create stream_id in + { + device; + runner; + merge_buffer = ref None; + stream_id; + allocated_buffer = None; + updating_for = Hashtbl.create (module Tnode); + updating_for_merge_buffer = None; + reader_streams = Hashtbl.create (module Tnode); + }) module Dynarr = Stdlib.Dynarray @@ -198,12 +212,9 @@ module Multicore (Backend : For_add_scheduler) : invalid_arg [%string "Multicore_scheduler.get_device %{ordinal#Int}: only device 0 exists"]; device - let latest_stream_id = ref (-1) - let new_stream _device = assert (Domain.is_main_domain ()); - Int.incr latest_stream_id; - let stream = spinup_stream ~stream_id:!latest_stream_id in + let stream = spinup_stream () in Stdlib.Gc.finalise cleanup_stream stream; stream @@ -252,12 +263,7 @@ module Sync (Backend : For_add_scheduler) = struct let num_devices () = 1 let suggested_num_streams _ = !sync_suggested_num_streams let get_used_memory _ = Backend.get_used_memory () - let latest_stram_id = ref (-1) - - let new_stream device = - Int.incr latest_stram_id; - make_stream device () ~stream_id:!latest_stram_id - + let new_stream device = make_stream device () let all_work _stream = () let is_idle _stream = true let await _stream = () diff --git a/arrayjit/lib/utils.ml b/arrayjit/lib/utils.ml index 6c80a22b..6be79c56 100644 --- a/arrayjit/lib/utils.ml +++ b/arrayjit/lib/utils.ml @@ -634,3 +634,30 @@ let capture_stdout_logs ?(never_skip = false) arg = advance_captured_logs := None; captured_log_processors := []); result) + +type 'a weak_dynarray = 'a Stdlib.Weak.t ref + +let weak_create () : 'a weak_dynarray = ref @@ Stdlib.Weak.create 0 + +let sexp_of_weak_dynarray sexp_of_elem arr = + 
+  sexp_of_array (sexp_of_option sexp_of_elem) Stdlib.Weak.(Array.init (length !arr) ~f:(get !arr))
+
+let register_new (arr : 'a weak_dynarray) ?(grow_by = 1) create =
+  let module W = Stdlib.Weak in
+  let old = !arr in
+  let pos = ref 0 in
+  while !pos < W.length old && W.check old !pos do
+    Int.incr pos
+  done;
+  if !pos >= W.length old then (
+    arr := Stdlib.Weak.create (W.length old + grow_by);
+    Stdlib.Weak.blit old 0 !arr 0 (Stdlib.Weak.length old));
+  let v = create !pos in
+  W.set !arr !pos (Some v);
+  v
+
+let weak_iter (arr : 'a weak_dynarray) ~f =
+  let module W = Stdlib.Weak in
+  for i = 0 to W.length !arr - 1 do
+    Option.iter (W.get !arr i) ~f
+  done
diff --git a/lib/dune b/lib/dune
index 6f96b34b..de27415a 100644
--- a/lib/dune
+++ b/lib/dune
@@ -11,6 +11,7 @@
   sexplib
   num
   str
+  landmarks
   ; mem_usage
   ppx_minidebug.runtime
   arrayjit)
@@ -24,6 +25,7 @@
    ppx_string
    ppx_variants_conv
    ppx_ocannl
+   landmarks-ppx
    ppx_minidebug))
 (modules PrintBox_utils row shape tensor operation train nn_blocks)
 (modes byte native)
diff --git a/lib/train.ml b/lib/train.ml
index bd3754fe..8838edcb 100644
--- a/lib/train.ml
+++ b/lib/train.ml
@@ -540,8 +540,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
            @@ Backend.get_debug_info s)
       in
       if per_epoch_debug_streams then _debug_at "before sync";
-      (* TODO: there should be nothing pending left to sync. *)
-      Array.iter streams ~f:Backend.await
+      (* TODO: there should be nothing pending left to sync; syncing here offers only a slight speed-up. *)
+      Array.iter devices ~f:Backend.sync_device
       (* This is now cleaned up by await. *)
       (* if per_epoch_debug_streams then _debug_at "after sync" *)
     done;
diff --git a/todo.md b/todo.md
index 1550f524..bfda29f9 100644
--- a/todo.md
+++ b/todo.md
@@ -3,4 +3,5 @@
 (B) remove syncing from the data parallel algo: stream-to-stream syncing is now automatic {cm:2024-11-23}
 (A) cuda backend crashes in bin/moons_benchmark {cm:2024-11-22}
 (B) figure out why cuda backend parallelism slows down in later epochs {cm:2024-11-25}
-(A) Ensure that reading from host on CPU performs required synchronization
\ No newline at end of file
+(A) Ensure that reading from host on CPU performs required synchronization
+clean up event hashtables when a stream or device gets synchronized {cm:2024-12-03}
\ No newline at end of file
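Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the weak-dynarray pattern this patch introduces in `arrayjit/lib/utils.ml` and uses for `device.streams`. It is written against plain `Stdlib` (no `Base`), and the `stream` record and names in the usage part are hypothetical stand-ins, not the project's actual types. The point it illustrates: `register_new` reuses the first empty (collected) slot, so a `stream_id` is only unique for the lifetime of its stream, and `weak_iter` (as used by `sync_device`) visits only streams that are still alive.

```ocaml
(* Sketch of the weak-dynarray pattern from utils.ml, using plain Stdlib.
   Slots hold values weakly: once a stream is no longer strongly referenced,
   the GC clears its slot, and register_new may reuse that slot (and hence
   its stream_id) for a later stream. *)

type 'a weak_dynarray = 'a Weak.t ref

let weak_create () : 'a weak_dynarray = ref (Weak.create 0)

(* Find the first empty slot, growing the backing array by [grow_by] when all
   slots are occupied; build the value from its slot index, then store it. *)
let register_new (arr : 'a weak_dynarray) ?(grow_by = 1) create =
  let old = !arr in
  let pos = ref 0 in
  while !pos < Weak.length old && Weak.check old !pos do incr pos done;
  if !pos >= Weak.length old then begin
    arr := Weak.create (Weak.length old + grow_by);
    Weak.blit old 0 !arr 0 (Weak.length old)
  end;
  let v = create !pos in
  Weak.set !arr !pos (Some v);
  v

(* Apply [f] to every entry that has not been garbage collected. *)
let weak_iter (arr : 'a weak_dynarray) ~f =
  for i = 0 to Weak.length !arr - 1 do
    match Weak.get !arr i with Some v -> f v | None -> ()
  done

(* Hypothetical stand-in for a backend stream, mirroring make_stream. *)
type stream = { stream_id : int; name : string }

let () =
  let streams : stream weak_dynarray = weak_create () in
  let s =
    register_new streams ~grow_by:8 (fun id ->
        { stream_id = id; name = Printf.sprintf "device0:%d" id })
  in
  (* Like sync_device: visit only the streams that are still alive. *)
  weak_iter streams ~f:(fun st -> print_endline st.name);
  ignore s
```

Design note: compared with the removed `latest_stream_id` counter, the weak array both bounds the ID space and gives `sync_device` a way to reach every live stream without itself keeping those streams alive.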