Skip to content

Commit

Permalink
Synchronize all devices of a stream, with cleanup; landmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Dec 3, 2024
1 parent 324bfc2 commit 772bea0
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 107 deletions.
6 changes: 4 additions & 2 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

- Interface files for `Backends` and `Low_level`.
- Fixed #245: tracking of used memory.
- TODO: stream-to-stream synchronization functionality, with lazy per-tensor-node synchronization.
- Stream-to-stream synchronization functionality, with lazy per-tensor-node synchronization.
- TODO: Automatic blocking on access of a host array when a scheduled `to_host` transfer has not finished.

### Changed

Expand All @@ -19,7 +20,8 @@
- Got rid of `subordinal`.
- Removed dependency on `core`, broke up dependency on `ppx_jane`.
- Huge refactoring of backend internal interfaces and API (not repeating same code).
- TODO: Built per-tensor-node stream-to-stream synchronization into copying functions, removed obsolete blocking synchronizations.
- Built per-tensor-node stream-to-stream synchronization into copying functions.
- Re-introduced whole-device blocking synchronization, which now is just a slight optimization as it also cleans up event book-keeping.

### Fixed

Expand Down
25 changes: 13 additions & 12 deletions arrayjit/lib/backend_impl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -101,26 +101,27 @@ struct
{
dev;
ordinal;
latest_stream_id = -1;
released = Atomic.make false;
cross_stream_candidates = Hashtbl.create (module Tnode);
owner_stream = Hashtbl.create (module Tnode);
shared_writer_streams = Hashtbl.create (module Tnode);
host_reading_streams = Hashtbl.create (module Tnode);
host_writing_streams = Hashtbl.create (module Tnode);
streams = Utils.weak_create ();
}

let make_stream device runner ~stream_id =
{
device;
runner;
merge_buffer = ref None;
stream_id;
allocated_buffer = None;
updating_for = Hashtbl.create (module Tnode);
updating_for_merge_buffer = None;
reader_streams = Hashtbl.create (module Tnode);
}
let make_stream device runner =
Utils.register_new device.streams ~grow_by:8 (fun stream_id ->
{
device;
runner;
merge_buffer = ref None;
stream_id;
allocated_buffer = None;
updating_for = Hashtbl.create (module Tnode);
updating_for_merge_buffer = None;
reader_streams = Hashtbl.create (module Tnode);
})

let get_name stream = [%string "%{name}:%{stream.device.ordinal#Int}:%{stream.stream_id#Int}"]

Expand Down
18 changes: 13 additions & 5 deletions arrayjit/lib/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ end
type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
dev : 'dev;
ordinal : int;
mutable latest_stream_id : int;
released : Utils.atomic_bool;
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
Expand All @@ -92,6 +91,7 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
host_writing_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
mutable streams : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Utils.weak_dynarray;
}

and ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
Expand All @@ -114,7 +114,6 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
dev : 'dev;
ordinal : int;
mutable latest_stream_id : int;
released : Utils.atomic_bool;
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
(** Freshly created arrays that might be shared across streams. The map can both grow and
Expand All @@ -136,6 +135,7 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been writing to a node's on-host array. The completed
events are removed opportunistically. *)
mutable streams : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Utils.weak_dynarray; (** . *)
}
[@@deriving sexp_of]

Expand All @@ -147,7 +147,7 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream =
(** Depending on backend implementations, either the currently used merge buffer, or the one
most recently scheduled. Note that the pointer can be reused for nodes that fit in an
already allocated buffer. *)
stream_id : int; (** An ID unique within the device. *)
stream_id : int; (** An ID unique within the device for the lifetime of the stream. *)
mutable allocated_buffer : 'buffer_ptr buffer option;
updating_for : 'event Hashtbl.M(Tnode).t;
(* The completion event for the most recent updating (writing to) a node via this stream. *)
Expand Down Expand Up @@ -188,7 +188,7 @@ module type Device = sig
include Alloc_buffer with type buffer_ptr := buffer_ptr and type stream := stream

val make_device : dev -> ordinal:int -> device
val make_stream : device -> runner -> stream_id:int -> stream
val make_stream : device -> runner -> stream

val make_context : ?ctx_arrays:ctx_arrays -> stream -> context
(** Returns a context without a parent. *)
Expand Down Expand Up @@ -291,6 +291,7 @@ module type Backend_device_common = sig
end

module type With_buffer_retrieval_and_syncing = sig
type device
type context
type event

Expand Down Expand Up @@ -318,6 +319,9 @@ module type With_buffer_retrieval_and_syncing = sig
buffer, and initializes the merge buffer's streaming event.
- If [into_merge_buffer=Copy], schedules copying from [src] to the merge buffer of [dst]'s
stream, and updates the writer event for the merge buffer. *)

val sync_device : device -> unit
(** Synchronizes all the streams on a device, and cleans up (removes) all associated events. *)
end

module type Backend = sig
Expand All @@ -331,5 +335,9 @@ module type Backend = sig
(** Returns the routines for the procedures included in the code batch. The returned context is
downstream of all the returned routines. *)

include With_buffer_retrieval_and_syncing with type context := context and type event := event
include
With_buffer_retrieval_and_syncing
with type device := device
and type context := context
and type event := event
end
61 changes: 36 additions & 25 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ let check_merge_buffer stream ~code_node =
^ ", expected by code: " ^ name code_node)

module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing) = struct
let wait_for_all ctx streams tn =
let[@landmark] wait_for_all ctx streams tn =
let s = ctx.stream in
Hashtbl.update_and_return streams tn
~f:
Expand All @@ -31,15 +31,15 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
|> List.iter ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then Backend.will_wait_for ctx e)

let wait_for_ready ~dst ~src tn =
let[@landmark] wait_for_ready ~dst ~src tn =
let s = src.stream in
let d = dst.stream in
(* TODO: maybe it's worthwhile to clean up s.updating_for every now and then. *)
Hashtbl.find s.updating_for tn
|> Option.iter ~f:(fun upd_e ->
if not (equal_stream s d || Backend.is_done upd_e) then Backend.will_wait_for dst upd_e)

let update_writer_event ?e ?from s tn =
let[@landmark] update_writer_event ?e ?from s tn =
let e = Option.value_or_thunk e ~default:(fun () -> Backend.all_work s) in
let f l = (s, e) :: Option.value ~default:[] l in
(match (from, tn) with
Expand All @@ -52,13 +52,14 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
| Node tn ->
if Tn.potentially_cross_stream tn then
Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l ->
(s, e) :: Option.value ~default:[] l);
(s, e) :: Option.value ~default:[] l)
else Hashtbl.remove s.device.shared_writer_streams tn;
Hashtbl.update s.updating_for tn ~f:(fun _ -> e)
| Merge_buffer tn ->
(* Note: the previous event does not need to be done! *)
s.updating_for_merge_buffer <- Some (tn, Some e)

let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
let%track2_l_sexp[@landmark] from_host (ctx : Backend.context) tn =
match (tn, Map.find ctx.ctx_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
wait_for_all ctx ctx.stream.reader_streams tn;
Expand All @@ -68,7 +69,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
true
| _ -> false

let%diagn2_l_sexp to_host (ctx : Backend.context) (tn : Tn.t) =
let%track2_l_sexp[@landmark] to_host (ctx : Backend.context) (tn : Tn.t) =
match (tn, Map.find ctx.ctx_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some src ->
if Tn.potentially_cross_stream tn then
Expand All @@ -82,8 +83,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
true
| _ -> false

let%diagn2_l_sexp device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : Backend.context)
~(src : Backend.context) =
let%diagn2_l_sexp[@landmark] device_to_device (tn : Tn.t) ~into_merge_buffer
~(dst : Backend.context) ~(src : Backend.context) =
let ordinal_of ctx = ctx.stream.device.ordinal in
let name_of ctx = Backend.(get_name ctx.stream) in
let same_device = ordinal_of dst = ordinal_of src in
Expand Down Expand Up @@ -115,30 +116,40 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
Backend.(
device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
dst.stream.updating_for_merge_buffer <- Some (tn, None);
Task.run task;
let[@landmark] merge_task () = Task.run task in
merge_task ();
update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
[%log "streaming into merge buffer", Tn.debug_name tn, "from", name_of src];
true)

let%track3_l_sexp sync_routine r =
let%track2_l_sexp sync_routine r =
let s = r.context.stream in
let pre () =
Hashtbl.filter_mapi_inplace s.device.shared_writer_streams ~f:(fun ~key ~data ->
if Tn.potentially_cross_stream key then
if Set.mem r.inputs key then (
let data = List.filter data ~f:(fun (_, e) -> Backend.is_done e) in
List.iter data ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then Backend.will_wait_for r.context e);
Some data)
else Some data
else None)
let[@landmark] pre () =
Set.iter r.inputs ~f:(fun tn ->
if Tn.potentially_cross_stream tn then
Option.iter (Hashtbl.find s.device.shared_writer_streams tn) ~f:(fun data ->
let data = List.filter data ~f:(fun (_, e) -> not (Backend.is_done e)) in
Hashtbl.set s.device.shared_writer_streams ~key:tn ~data;
List.iter data ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then Backend.will_wait_for r.context e))
else Hashtbl.remove s.device.shared_writer_streams tn)
(* Since merge buffers are always per-stream, no need to check r.merge_buffer_input. *)
in
let post () =
let[@landmark] post () =
let e = Backend.all_work s in
Set.iter r.outputs ~f:(fun tn -> update_writer_event ~e s @@ Node tn)
in
{ r with schedule = Task.(prepend ~work:pre @@ append ~work:post r.schedule) }

let[@landmark] sync_device device =
Utils.weak_iter device.streams ~f:Backend.await;
Hashtbl.clear device.host_writing_streams;
Hashtbl.clear device.host_reading_streams;
Hashtbl.clear device.shared_writer_streams;
Utils.weak_iter device.streams ~f:(fun s ->
Hashtbl.clear s.reader_streams;
s.updating_for_merge_buffer <- None;
Hashtbl.clear s.updating_for)
end

let lower_assignments ?name bindings asgns =
Expand Down Expand Up @@ -268,20 +279,20 @@ module Add_device
in
(Option.value_exn ~here:[%here] bindings, schedules)

let from_host ~dst_ptr ~dst hosted =
let[@landmark] from_host ~dst_ptr ~dst hosted =
let work () = host_to_buffer hosted ~dst:dst_ptr in
(* TODO: pass description to from_host. *)
schedule_task dst.stream
(Task.Task
{ context_lifetime = dst; description = "from_host on " ^ get_name dst.stream; work })

let to_host ~src_ptr ~src hosted =
let[@landmark] to_host ~src_ptr ~src hosted =
let work () = buffer_to_host hosted ~src:src_ptr in
(* TODO: pass description to to_host. *)
schedule_task src.stream
(Task.Task { context_lifetime = src; description = "to_host on " ^ get_name src.stream; work })

let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
let[@landmark] device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
let s = dst.stream in
let size_in_bytes = Tnode.size_in_bytes tn in
let work =
Expand Down Expand Up @@ -468,7 +479,7 @@ let reinitialize (module Backend : Backend) config =
Stdlib.Gc.full_major ();
Backend.initialize config)

let%track3_sexp finalize (type buffer_ptr dev runner event)
let[@landmark] finalize (type buffer_ptr dev runner event)
(module Backend : Backend
with type buffer_ptr = buffer_ptr
and type dev = dev
Expand Down
19 changes: 10 additions & 9 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
let () =
Cu.cuda_call_hook :=
Some
(fun ~message ~status ->
(fun ~message:_message ~status:_status ->
[%debug_l_sexp
[%log5_block
message;
if not @@ Cu.is_success status then [%log (status : Cu.result)]]])
_message;
if not @@ Cu.is_success _status then [%log (_status : Cu.result)]]])

let _suspended () =
Cu.cuda_call_hook := Some (fun ~message ~status:_ -> Stdlib.Printf.printf "CUDA %s\n" message)
Expand Down Expand Up @@ -149,11 +149,10 @@ let%track3_sexp get_device ~(ordinal : int) : device =
if Atomic.get result.released then default () else result

let%track3_sexp new_stream (device : device) : stream =
device.latest_stream_id <- device.latest_stream_id + 1;
(* Strange that we need ctx_set_current even with a single device! *)
set_ctx device.dev.primary_context;
let cu_stream = Cu.Stream.create ~non_blocking:true () in
make_stream device cu_stream ~stream_id:device.latest_stream_id
make_stream device cu_stream

let cuda_properties =
let cache =
Expand All @@ -173,24 +172,24 @@ let suggested_num_streams device =
| For_parallel_copying -> 1 + (cuda_properties device).async_engine_count
| Most_parallel_streams -> (cuda_properties device).multiprocessor_count

let await stream : unit =
let[@landmark] await stream : unit =
set_ctx stream.device.dev.primary_context;
Cu.Stream.synchronize stream.runner;
Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ())

let is_idle stream = Cu.Stream.is_ready stream.runner

let from_host ~dst_ptr ~dst hosted =
let[@landmark] from_host ~dst_ptr ~dst hosted =
set_ctx @@ ctx_of dst;
let f src = Cu.Stream.memcpy_H_to_D ~dst:dst_ptr ~src dst.stream.runner in
Ndarray.map { f } hosted

let to_host ~src_ptr ~src hosted =
let[@landmark] to_host ~src_ptr ~src hosted =
set_ctx @@ ctx_of src;
let f dst = Cu.Stream.memcpy_D_to_H ~dst ~src:src_ptr src.stream.runner in
Ndarray.map { f } hosted

let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
let[@landmark] device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
let dev = dst.stream.device in
let same_device = dev.ordinal = src.stream.device.ordinal in
let size_in_bytes = Tn.size_in_bytes tn in
Expand Down Expand Up @@ -248,6 +247,8 @@ let%diagn2_sexp cuda_to_ptx ~name cu_src =
let options =
"--use_fast_math" :: (if Utils.with_runtime_debug () then [ "--device-debug" ] else [])
in
(* FIXME: every now and then the compilation crashes because the options are garbled. *)
(* Stdio.printf "PTX options %s\n%!" @@ String.concat ~sep:", " options; *)
let ptx = Cu.Nvrtc.compile_to_ptx ~cu_src ~name:name_cu ~options ~with_debug in
if Utils.settings.output_debug_files_in_build_directory then (
let oc = Out_channel.open_text @@ Utils.build_file @@ name ^ ".ptx" in
Expand Down
2 changes: 2 additions & 0 deletions arrayjit/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ctypes
ctypes.foreign
saturn_lockfree
landmarks
(select
gcc_backend.ml
from
Expand All @@ -31,6 +32,7 @@
ppx_sexp_conv
ppx_string
ppx_variants_conv
landmarks-ppx
ppx_minidebug))
(modules
utils
Expand Down
Loading

0 comments on commit 772bea0

Please sign in to comment.