Skip to content

Commit

Permalink
Fixes the memory model: on-host arrays can be in contexts
Browse files Browse the repository at this point in the history
Double check: not possible it would trigger freeing host array.
Still broken: cc backend tests hang.
  • Loading branch information
lukstafi committed Dec 10, 2024
1 parent 25c71e5 commit d4277b2
Show file tree
Hide file tree
Showing 16 changed files with 196 additions and 169 deletions.
13 changes: 10 additions & 3 deletions arrayjit/lib/anatomy_of_a_backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ In the future, when we introduce program search, `compile` functions will return
OCANNL classifies tensor nodes according to their memory properties:

```ocaml
(** A possible algorithm for deciding sharing within a single device:
(** A possible algorithm for deciding sharing within a single device:
- If a tensor node is read-only for a context, and not otherwise recorded, it is stored as a
cross-stream sharing candidate.
- If a cross-stream sharing candidate is read-only for another context, whose parent does not
Expand All @@ -71,9 +71,14 @@ OCANNL classifies tensor nodes according to their memory properties:
If a tensor node is shared cross-stream, within-device copying is a NOOP as source and
destination pointers are in that case identical. *)
type sharing =
| Unset
| Unset (** One of: [Per_stream], [Shared_cross_streams]. *)
| Per_stream (** The tensor node has separate arrays for each stream. *)
| Shared_cross_stream (** The tensor node has a single array per device. *)
| Shared_cross_streams
(** The tensor node has a single array per device that can appear in multiple contexts, except
for backends with [Option.is_some use_host_memory] and nodes with memory mode already
[Hosted (Changed_on_devices Shared_cross_streams)] before first linking on a device, where
it only has the on-host array. In that case the on-host array is registered in the
context, to avoid misleading behavior from `device_to_device`. *)
type memory_type =
| Constant (** The tensor node does not change after initialization. *)
Expand Down Expand Up @@ -110,6 +115,8 @@ A backend can make more refined distinctions, for example a `Local` node in CUDA

Contexts track (or store) the on-device arrays corresponding to tensor nodes. Contexts form a hierarchy: linking takes a parent context and outputs a child context. Related contexts that use a tensor node must use the same on-device array for the tensor node. If two unrelated contexts are on the same device, i.e. have a common ancestor, and use the same tensor node that is not part of the most recent common ancestor, the behavior is undefined.

To avoid misleading behavior of `device_to_device` data movement, non-constant materialized tensor nodes are represented in contexts making use of them, even when the underlying array is on host. This way the logic remains the same regardless of whether a backend shares memory with the host. We are careful to not accidentally call `free_buffer` on hosted arrays.

## Typical details of a backend implementation

During the compilation process, the old context cannot be available when `compile` is handled. Currently, all backends generate context-and-device-independent kernels, that refer to context arrays via parameters.
Expand Down
11 changes: 4 additions & 7 deletions arrayjit/lib/assignments.ml
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,13 @@ let get_name_exn asgns =
let is_total ~initialize_neutral ~projections =
initialize_neutral && Indexing.is_bijective projections

(** Returns materialized nodes in the sense of {!Tnode.is_in_context}. NOTE: it should be called
after compilation and ideally after linking with the relevant contexts; otherwise, it is an
under-estimate. *)
let%debug3_sexp context_nodes ~(use_host_memory : bool) (asgns : t) : Tn.t_set =
(** Returns materialized nodes in the sense of {!Tnode.is_in_context_force}. NOTE: it must be called
after compilation; otherwise, it will disrupt memory mode inference. *)
let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_set =
let open Utils.Set_O in
let empty = Set.empty (module Tn) in
let one tn =
if Option.value ~default:false @@ Tnode.is_in_context ~use_host_memory tn then
Set.singleton (module Tn) tn
else empty
if Tn.is_in_context_force ~use_host_memory tn 34 then Set.singleton (module Tn) tn else empty
in
let of_node = function Node rhs -> one rhs | Merge_buffer _ -> empty in
let rec loop = function
Expand Down
12 changes: 7 additions & 5 deletions arrayjit/lib/backend_impl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ open Backend_intf
module type No_device_buffer_and_copying = sig
include Alloc_buffer with type stream := unit

val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option

val get_used_memory : unit -> int
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)

Expand All @@ -28,6 +30,7 @@ module No_device_buffer_and_copying () :
No_device_buffer_and_copying with type buffer_ptr = unit Ctypes.ptr = struct
type buffer_ptr = unit Ctypes.ptr

let use_host_memory = Some Fn.id
let sexp_of_buffer_ptr = Ops.sexp_of_voidptr

include Buffer_types (struct
Expand Down Expand Up @@ -70,8 +73,6 @@ module No_device_buffer_and_copying () :
Ctypes_memory_stubs.memcpy
~dst:(Ndarray.get_fatptr_not_managed dst)
~src ~size:(Ndarray.size_in_bytes dst)

let c_ptr_to_string = Some Ops.c_ptr_to_string
end

module Device_types (Device_config : Device_config) = struct
Expand Down Expand Up @@ -133,10 +134,11 @@ end
module type Backend_impl_common = sig
include Buffer

val use_host_memory : bool
(** If true, the backend will read from and write to the host memory directly whenever possible.
val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option
(** If not [None], the backend will read from and write to the host memory directly whenever
reasonable.
[use_host_memory] can only be true on unified memory devices, like CPU and Apple Metal. *)
[use_host_memory] can only be [Some] on unified memory devices, like CPU and Apple Metal. *)
end

(** An interface to adding schedulers for stream-agnostic (typically CPU) backend implementations.
Expand Down
2 changes: 0 additions & 2 deletions arrayjit/lib/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ end
module type Buffer = sig
type buffer_ptr [@@deriving sexp_of]

val c_ptr_to_string : (buffer_ptr -> Ops.prec -> string) option

include module type of Buffer_types (struct
type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of]
end)
Expand Down
35 changes: 21 additions & 14 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
let ordinal_of ctx = ctx.stream.device.ordinal in
let name_of ctx = Backend.(get_name ctx.stream) in
let same_device = ordinal_of dst = ordinal_of src in
if same_device && (Tn.known_shared_cross_stream tn || String.equal (name_of src) (name_of dst))
if same_device && (Tn.known_shared_cross_streams tn || String.equal (name_of src) (name_of dst))
then false
else
match Map.find src.ctx_arrays tn with
Expand Down Expand Up @@ -187,8 +187,7 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l =
let%debug3_sexp verify_prior_context ~use_host_memory ~ctx_arrays ~from_prior_context : unit =
Set.iter from_prior_context ~f:(fun tn ->
if
(* Err on the safe side. *)
Option.value ~default:false (Tn.is_in_context ~use_host_memory tn)
Tn.is_in_context_force ~use_host_memory tn 42
&& not (Option.is_some @@ Map.find ctx_arrays tn)
then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn))

Expand Down Expand Up @@ -349,27 +348,35 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
}

let%track3_sexp alloc_if_needed (stream : stream) ~key ~data:node ctx_arrays =
(* TODO: do we need this? *)
(* Tn.default_to_most_local key 345; *)
if
Option.value ~default:true (Tnode.is_in_context ~use_host_memory key)
&& not (Map.mem ctx_arrays key)
then (
if Tnode.is_in_context_force ~use_host_memory key 43 && not (Map.mem ctx_arrays key) then (
[%log Tn.debug_name key];
[%log (key : Tnode.t)];
let default () =
alloc_zero_init_array (Lazy.force key.prec) ~dims:(Lazy.force key.dims) stream
in
let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in
let device = stream.device in
if node.Low_level.read_only then
if node.Low_level.read_only then (
if Tn.known_non_cross_stream key then add_new ()
else (
else
let data =
match use_host_memory with
| None -> Hashtbl.find_or_add device.cross_stream_candidates key ~default
| Some get_buffer_ptr ->
if
(not (Hashtbl.mem device.cross_stream_candidates key))
&& Tn.known_shared_cross_streams key && Tn.is_hosted_force key 44
then
Hashtbl.update_and_return device.cross_stream_candidates key ~f:(fun _ ->
get_buffer_ptr @@ Ndarray.get_voidptr_not_managed
@@ Option.value_exn ~here:[%here]
@@ Lazy.force key.array)
else Hashtbl.find_or_add device.cross_stream_candidates key ~default
in
if Hashtbl.mem device.cross_stream_candidates key then
Tn.update_memory_sharing key Tn.Shared_cross_stream 39;
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
Tn.update_memory_sharing key Tn.Shared_cross_streams 39;
Map.add_exn ctx_arrays ~key ~data)
else if Tn.known_shared_cross_stream key then (
else if Tn.known_shared_cross_streams key then (
if Hashtbl.mem device.owner_stream key then (
if not (equal_stream stream (Hashtbl.find_exn device.owner_stream key)) then
raise
Expand Down
10 changes: 6 additions & 4 deletions arrayjit/lib/c_syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ module C_syntax (B : sig
(** The low-level prcedure to compile, and the arrays of the context it will be linked to if not
shared and already known. *)

val use_host_memory : bool
type buffer_ptr

val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option
val logs_to_stdout : bool
val main_kernel_prefix : string
val kernel_prep_line : string
Expand All @@ -29,7 +31,7 @@ struct
let get_ident =
Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun l -> l.llc)

let in_ctx tn = B.(Tn.is_in_context ~use_host_memory tn)
let in_ctx tn = B.(Tn.is_in_context_force ~use_host_memory tn 46)

let pp_zero_out ppf tn =
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn
Expand Down Expand Up @@ -268,8 +270,8 @@ struct
(* A rough approximation to the type Gccjit_backend.mem_properties. *)
let backend_info, is_param =
if Tn.is_virtual_force tn 334 then ("Virt", false)
else if Option.value ~default:false @@ in_ctx tn then ("Ctx", true)
else if Tn.is_materialized_force tn 335 then ("Global or ctx", true)
else if in_ctx tn then ("Ctx", true)
else if Tn.is_materialized_force tn 335 then ("Global", true)
else if Tn.known_not_materialized tn then ("Local", false)
else assert false
in
Expand Down
29 changes: 14 additions & 15 deletions arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ type procedure = {
}
[@@deriving sexp_of]

let use_host_memory = true

let get_global_run_id =
let next_id = ref 0 in
fun () ->
Expand Down Expand Up @@ -76,6 +74,9 @@ module C_syntax_config (Input : sig
end) =
struct
let procs = Input.procs

type nonrec buffer_ptr = buffer_ptr

let use_host_memory = use_host_memory
let logs_to_stdout = false
let main_kernel_prefix = ""
Expand Down Expand Up @@ -127,14 +128,6 @@ let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized opt

let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) =
let name : string = code.name in
List.iter code.params ~f:(function
| _, Param_ptr tn when not @@ Tn.known_shared_with_host ~use_host_memory tn ->
if not (Map.mem ctx_arrays tn) then
invalid_arg
[%string
"Cc_backend.link_compiled: node %{Tn.debug_name tn} missing from context: \
%{Tn.debug_memory_mode tn.Tn.memory_mode}"]
| _ -> ());
let log_file_name = Utils.diagn_log_file [%string "debug-%{runner_label}-%{code.name}.log"] in
let run_variadic =
[%log_level
Expand All @@ -158,11 +151,16 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
Param_2f (get_ptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs))
| bs, Param_ptr tn :: ps ->
let c_ptr =
if Tn.known_shared_with_host ~use_host_memory tn then
Ndarray.get_voidptr_not_managed
@@ Option.value_exn ~here:[%here]
@@ Lazy.force tn.array
else Map.find_exn ctx_arrays tn
match Map.find ctx_arrays tn with
| None ->
Ndarray.get_voidptr_not_managed
@@ Option.value_exn ~here:[%here]
~message:
[%string
"Cc_backend.link_compiled: node %{Tn.debug_name tn} missing from \
context: %{Tn.debug_memory_mode tn.Tn.memory_mode}"]
@@ Lazy.force tn.array
| Some arr -> arr
in
Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs))
in
Expand All @@ -174,6 +172,7 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
in
let%diagn_l_sexp work () : unit =
[%log_result name];
(* Stdio.printf "launching %s\n" name; *)
Indexing.apply run_variadic ();
if Utils.debug_log_from_routines () then (
Utils.log_trace_tree (Stdio.In_channel.read_lines log_file_name);
Expand Down
17 changes: 10 additions & 7 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@ module Backend_buffer = struct
type buffer_ptr = Cu.Deviceptr.t

let sexp_of_buffer_ptr ptr = Sexp.Atom (Cu.Deviceptr.string_of ptr)
let c_ptr_to_string = None

include Buffer_types (struct
type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of]
end)
end

let use_host_memory = false
let use_host_memory = None

module Device_config = struct
include Backend_buffer
Expand Down Expand Up @@ -156,15 +155,18 @@ let%track3_sexp new_stream (device : device) : stream =

let cuda_properties =
let cache =
lazy
(Array.init (num_devices ()) ~f:(fun ordinal ->
let dev = get_device ~ordinal in
lazy (Cu.Device.get_attributes dev.dev.dev)))
let%debug2_sexp f (ordinal : int) =
let dev = get_device ~ordinal in
lazy (Cu.Device.get_attributes dev.dev.dev)
in
lazy (Array.init (num_devices ()) ~f)
in
fun device ->
let%debug2_sexp get_props (device : device) : Cu.Device.attributes =
if not @@ is_initialized () then invalid_arg "cuda_properties: CUDA not initialized";
let cache = Lazy.force cache in
Lazy.force cache.(device.ordinal)
in
get_props

let suggested_num_streams device =
match !global_config with
Expand Down Expand Up @@ -427,6 +429,7 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
(* Map.iteri ctx_arrays ~f:(fun ~key ~data:ptr -> if key.Low_level.zero_initialized then
Cu.Stream.memset_d8 ptr Unsigned.UChar.zero ~length:(Tn.size_in_bytes key.Low_level.tn)); *)
[%log "launching the kernel"];
(* Stdio.printf "launching %s\n" name; *)
(if Utils.debug_log_from_routines () then
Utils.add_log_processor ~prefix:log_id_prefix @@ fun _output ->
[%log_block
Expand Down
28 changes: 17 additions & 11 deletions arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type procedure = {
}
[@@deriving sexp_of]

let use_host_memory = true
(* let use_host_memory = true *)

let gcc_typ_of_prec =
let open Gccjit in
Expand Down Expand Up @@ -114,14 +114,14 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~param_ptrs ini
let ptr_typ = Type.pointer num_typ in
let ident = get_ident tn in
let hosted = Tn.is_hosted_force tn 344 in
let in_ctx = Tn.is_in_context ~use_host_memory tn in
let in_ctx = Tn.is_in_context_force ~use_host_memory tn 45 in
let ptr =
match (in_ctx, hosted) with
| Some true, _ ->
| true, _ ->
let p = Param.create ctx ptr_typ ident in
param_ptrs := (p, Param_ptr tn) :: !param_ptrs;
Lazy.from_val (RValue.param p)
| (Some false | None), true -> (
| false, true -> (
let addr arr =
Lazy.from_val @@ get_c_ptr ctx num_typ @@ Ctypes.bigarray_start Ctypes_static.Genarray arr
in
Expand All @@ -131,7 +131,7 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~param_ptrs ini
| Some (Single_nd arr) -> addr arr
| Some (Double_nd arr) -> addr arr
| None -> assert false)
| (Some false | None), false ->
| false, false ->
let arr_typ = Type.array ctx num_typ size_in_elems in
let v = ref None in
let initialize _init_block func = v := Some (Function.local func arr_typ ident) in
Expand Down Expand Up @@ -644,7 +644,8 @@ let%diagn_sexp compile_batch ~(names : string option array) bindings
let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) =
let name : string = code.name in
List.iter code.params ~f:(function
| Param_ptr tn when not (Tn.known_shared_cross_stream tn) -> assert (Map.mem ctx_arrays tn)
(* FIXME: see cc_backend.ml *)
| Param_ptr tn when not (Tn.known_shared_cross_streams tn) -> assert (Map.mem ctx_arrays tn)
| _ -> ());
let log_file_name = Utils.diagn_log_file [%string "debug-%{runner_label}-%{code.name}.log"] in
let run_variadic =
Expand All @@ -667,11 +668,16 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs))
| bs, Param_ptr tn :: ps ->
let c_ptr =
if Tn.known_shared_with_host ~use_host_memory tn then
Ndarray.get_voidptr_not_managed
@@ Option.value_exn ~here:[%here]
@@ Lazy.force tn.array
else Map.find_exn ctx_arrays tn
match Map.find ctx_arrays tn with
| None ->
Ndarray.get_voidptr_not_managed
@@ Option.value_exn ~here:[%here]
~message:
[%string
"Gcc_backend.link_compiled: node %{Tn.debug_name tn} missing from \
context: %{Tn.debug_memory_mode tn.Tn.memory_mode}"]
@@ Lazy.force tn.array
| Some arr -> arr
in
Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs))
| bs, Merge_buffer :: ps ->
Expand Down
3 changes: 1 addition & 2 deletions arrayjit/lib/lowered_backend_missing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ type dev
type runner
type event

let use_host_memory = false
let use_host_memory = None
let sexp_of_dev _dev = failwith "Backend missing -- install the corresponding library"
let sexp_of_runner _runner = failwith "Backend missing -- install the corresponding library"
let sexp_of_event _event = failwith "Backend missing -- install the corresponding library"
Expand Down Expand Up @@ -39,7 +39,6 @@ let make_child ?ctx_arrays:_ _context =

let get_name _stream = failwith "Backend missing -- install the corresponding library"
let sexp_of_buffer_ptr _buffer_ptr = failwith "Backend missing -- install the corresponding library"
let c_ptr_to_string = None

type nonrec buffer = buffer_ptr Backend_intf.buffer

Expand Down
Loading

0 comments on commit d4277b2

Please sign in to comment.