Skip to content

Commit

Permalink
Fix auto transfer from/to host in presence of multiple devices
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jan 1, 2025
1 parent a91751b commit 7d333cd
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
4 changes: 4 additions & 0 deletions arrayjit/lib/backend_impl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ module Device_types (Device_config : Device_config) = struct
type nonrec context = (buffer_ptr, stream) context [@@deriving sexp_of]
end

(** Process-wide counter that supplies the [device_id] of each newly made device instance. It starts
    at 0 and is advanced atomically by 1 every time a device is created, so that ids are not reused
    across device instances. *)
let next_global_device_id : Utils.atomic_int = Atomic.make 0

module Device
(Device_types : Device_types)
(Alloc_buffer :
Expand All @@ -104,9 +106,11 @@ struct
include Alloc_buffer

let make_device dev ~ordinal =
let device_id = Atomic.fetch_and_add next_global_device_id 1 in
{
dev;
ordinal;
device_id;
cross_stream_candidates = Hashtbl.create (module Tnode);
owner_stream = Hashtbl.create (module Tnode);
shared_writer_streams = Hashtbl.create (module Tnode);
Expand Down
10 changes: 8 additions & 2 deletions arrayjit/lib/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ end
type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
dev : 'dev;
ordinal : int;
device_id : int;
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
shared_writer_streams :
Expand Down Expand Up @@ -111,6 +112,11 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
dev : 'dev;
ordinal : int;
(** The number of the represented backend's device, in the range from 0 to the number of the
backend's devices - 1. *)
device_id : int;
(** A unique identifier among all device instances of all backends. Note that multiple
[device_id] values (distinct device instances) might refer to the same physical device. *)
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
(** Freshly created arrays that might be shared across streams. The map can both grow and
shrink. *)
Expand Down Expand Up @@ -248,8 +254,8 @@ module type Backend_device_common = sig
val sync : event -> unit
(** Blocks till the event completes, if it's not done already.
It is rarely needed to call [sync] explicitly, because it should always be
called internally when necessary, in particular before extracting values from host. *)
It is rarely needed to call [sync] explicitly, because it should always be called internally
when necessary, in particular before extracting values from host. *)

val is_done : event -> bool
(** Whether the event completed. *)
Expand Down
7 changes: 4 additions & 3 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
(* Stdio.printf "copying: %s from_host\n" (Tn.debug_name tn); *)
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
update_writer_event ~from:`Host ctx @@ Node tn;
tn.host_modified <- false;
Hash_set.add tn.host_read_by_devices ctx.stream.device.device_id;
true
| _ -> false

Expand Down Expand Up @@ -146,7 +146,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
assert (Domain.is_main_domain ());
if Utils.settings.automatic_host_transfers then
Set.iter hosted_inputs ~f:(fun tn ->
if tn.host_modified then assert (from_host r.context tn));
if not (Hash_set.mem tn.host_read_by_devices s.device.device_id) then
assert (from_host r.context tn));
Set.iter r.inputs ~f:(fun tn ->
if Tn.potentially_cross_stream tn then
Option.iter (Hashtbl.find s.device.shared_writer_streams tn) ~f:(fun data ->
Expand Down Expand Up @@ -386,7 +387,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
match key.array with
| (lazy (Some hosted)) ->
Device.from_host ~dst_ptr ~dst:parent_context hosted;
key.host_modified <- false
Hash_set.add key.host_read_by_devices stream.device.device_id
| _ -> ());
dst_ptr
in
Expand Down
9 changes: 5 additions & 4 deletions arrayjit/lib/tnode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ type t = {
mutable code_name : string option;
mutable prepare_read : prepare option;
mutable prepare_write : prepare option;
mutable host_modified : bool;
mutable host_read_by_devices : Hash_set.M(Int).t;
(** The unique ids of devices that read the most recent modification of the host array. *)
}
[@@deriving sexp_of]

Expand Down Expand Up @@ -554,7 +555,7 @@ let create ?default_prec ~id ~label ~dims init_op =
code_name = None;
prepare_read = None;
prepare_write = None;
host_modified = true;
host_read_by_devices = Hash_set.create (module Int);
}
in
(* Note: if tensor nodes get non-trivial finalizers, remember to either add an is_finalized flag
Expand All @@ -578,7 +579,7 @@ let find =
code_name = None;
prepare_read = None;
prepare_write = None;
host_modified = false;
host_read_by_devices = Hash_set.create (module Int);
}
in
fun ~id -> Registry.find_opt registry { mock with id }
Expand All @@ -596,7 +597,7 @@ let do_read tn =
(** Prepares [tn] for a write on the host side: runs the pending [prepare_write] synchronization, if
    there is one, and drops it so that it is not run again; then marks the host array as modified,
    i.e. no device has yet read its most recent contents. *)
let do_write tn =
Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
tn.prepare_write <- None;
tn.host_modified <- true
Hash_set.clear tn.host_read_by_devices

let points_1d ?from_axis ~xdim tn =
do_read tn;
Expand Down

0 comments on commit 7d333cd

Please sign in to comment.