Skip to content

Commit

Permalink
In progress: get rid of hard-coded pointers, and of opt_ctx_arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Dec 8, 2024
1 parent b5d6104 commit 606f3d2
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 104 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- Huge refactoring of backend internal interfaces and API (not repeating same code).
- Built per-tensor-node stream-to-stream synchronization into copying functions.
- Re-introduced whole-device blocking synchronization, which now is just a slight optimization as it also cleans up event book-keeping.
- Simplifications: no more explicit compilation postponing; no more hard-coded pointers (all non-local arrays are passed by parameter).

### Fixed

Expand Down
12 changes: 1 addition & 11 deletions arrayjit/lib/backend_impl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -158,23 +158,13 @@ module type Lowered_no_device_backend = sig

type procedure [@@deriving sexp_of]

val compile :
name:string ->
opt_ctx_arrays:ctx_arrays option ->
Indexing.unit_bindings ->
Low_level.optimized ->
procedure
(** [opt_ctx_arrays], if any, already contain the arrays of the context that will result from
linking the code. *)
val compile : name:string -> Indexing.unit_bindings -> Low_level.optimized -> procedure

val compile_batch :
names:string option array ->
opt_ctx_arrays:ctx_arrays option array option ->
Indexing.unit_bindings ->
Low_level.optimized option array ->
procedure option array
(** [opt_ctx_arrays], if any, already contain the arrays of the contexts that will result from
linking the code. *)

val link_compiled :
merge_buffer:buffer option ref ->
Expand Down
4 changes: 2 additions & 2 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,11 @@ module Add_device
[@@deriving sexp_of]

let compile ~name bindings lowered : code =
let proc = compile ~name ~opt_ctx_arrays:None bindings lowered in
let proc = compile ~name bindings lowered in
{ lowered; proc }

let compile_batch ~names bindings lowereds : code_batch =
let procs = compile_batch ~names ~opt_ctx_arrays:None bindings lowereds in
let procs = compile_batch ~names bindings lowereds in
{ lowereds; procs }

include Add_scheduler (Backend)
Expand Down
57 changes: 12 additions & 45 deletions arrayjit/lib/c_syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,31 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
module Tn = Tnode

module C_syntax (B : sig
type buffer_ptr

val procs : (Low_level.optimized * buffer_ptr ctx_arrays option) array
val procs : Low_level.optimized array
(** The low-level prcedure to compile, and the arrays of the context it will be linked to if not
shared and already known. *)

val hardcoded_context_ptr : (buffer_ptr -> Ops.prec -> string) option
val use_host_memory : bool
val logs_to_stdout : bool
val main_kernel_prefix : string
val kernel_prep_line : string
val include_lines : string list
val includes : string list
val typ_of_prec : Ops.prec -> string
val binop_syntax : Ops.prec -> Ops.binop -> string * string * string
val unop_syntax : Ops.prec -> Ops.unop -> string * string
val convert_precision : from:Ops.prec -> to_:Ops.prec -> string * string
end) =
struct
let get_ident =
Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun (l, _) -> l.llc)
Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun l -> l.llc)

let in_ctx tn = B.(Tn.is_in_context ~use_host_memory tn)

let pp_zero_out ppf tn =
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn

let pp_include ppf s = Stdlib.Format.fprintf ppf "#include %s" s

open Indexing.Pp_helpers

let pp_array_offset ppf (idcs, dims) =
Expand All @@ -61,33 +60,8 @@ struct

(* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim
-> idx + (offset * dim)) *)
let%debug3_sexp compile_globals ppf : Tn.t Hash_set.t =
let open Stdlib.Format in
let is_global = Hash_set.create (module Tn) in
fprintf ppf {|@[<v 0>%a@,/* Global declarations. */@,|} (pp_print_list pp_print_string)
B.include_lines;
Array.iter B.procs ~f:(fun (l, ctx_arrays) ->
Hashtbl.iter l.Low_level.traced_store ~f:(fun (node : Low_level.traced_array) ->
let tn = node.tn in
if not @@ Hash_set.mem is_global tn then
let ctx_ptr = B.hardcoded_context_ptr in
let mem : (Tn.memory_mode * int) option = tn.memory_mode in
match (in_ctx tn, ctx_ptr, ctx_arrays, mem) with
| Some true, Some get_ptr, Some ctx_arrays, _ ->
let ident = get_ident tn in
let ctx_array =
Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays tn
in
fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array (Lazy.force tn.prec);
Hash_set.add is_global tn
| Some false, _, _, Some (Hosted _, _)
when B.(Tn.known_shared_with_host ~use_host_memory tn) ->
let nd = Option.value_exn ~here:[%here] @@ Lazy.force tn.array in
fprintf ppf "#define %s (%s)@," (get_ident tn) (Ndarray.c_ptr_to_string nd);
Hash_set.add is_global tn
| _ -> ()));
fprintf ppf "@,@]";
is_global
let print_includes ppf =
Stdlib.Format.(fprintf ppf {|@[<v 0>%a@,|} (pp_print_list pp_include) B.includes)
let compile_main ~traced_store ppf llc : unit =
let open Stdlib.Format in
Expand Down Expand Up @@ -285,18 +259,16 @@ struct
in
pp_ll ppf llc
let%track3_sexp compile_proc ~name ppf idx_params ~is_global
Low_level.{ traced_store; llc; merge_node } =
let%track3_sexp compile_proc ~name ppf idx_params Low_level.{ traced_store; llc; merge_node } =
let open Stdlib.Format in
let params : (string * param_source) list =
(* Preserve the order in the hashtable, so it's the same as e.g. in compile_globals. *)
(* Preserve the order in the hashtable. *)
List.rev
@@ Hashtbl.fold traced_store ~init:[] ~f:(fun ~key:tn ~data:_ params ->
(* A rough approximation to the type Gccjit_backend.mem_properties. *)
let backend_info =
Sexp.Atom
(if Hash_set.mem is_global tn then "Host"
else if Tn.is_virtual_force tn 334 then "Virt"
(if Tn.is_virtual_force tn 334 then "Virt"
else
match in_ctx tn with
| Some true -> "Ctx"
Expand All @@ -307,7 +279,7 @@ struct
tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
(* We often don't know ahead of linking with relevant contexts what the stream sharing
mode of the node will become. Conservatively, use passing as argument. *)
if Option.value ~default:true (in_ctx tn) && not (Hash_set.mem is_global tn) then
if Option.value ~default:true (in_ctx tn) then
(B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params
else params)
in
Expand Down Expand Up @@ -373,12 +345,7 @@ struct
params);
fprintf ppf "/* Local declarations and initialization. */@ ";
Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
if
not
(Tn.is_virtual_force tn 333
|| Option.value ~default:true (in_ctx tn)
|| Hash_set.mem is_global tn)
then
if not (Tn.is_virtual_force tn 333 || Option.value ~default:true (in_ctx tn)) then
fprintf ppf "%s %s[%d]%s;@ "
(B.typ_of_prec @@ Lazy.force tn.prec)
(get_ident tn) (Tn.num_elems tn)
Expand Down
30 changes: 10 additions & 20 deletions arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -72,47 +72,37 @@ let c_compile_and_load ~f_name =
result

module C_syntax_config (Input : sig
val procs : (Low_level.optimized * buffer_ptr ctx_arrays option) array
val procs : Low_level.optimized array
end) =
struct
type nonrec buffer_ptr = buffer_ptr

let procs = Input.procs
let hardcoded_context_ptr = c_ptr_to_string
let use_host_memory = use_host_memory
let logs_to_stdout = false
let main_kernel_prefix = ""
let kernel_prep_line = ""

let include_lines =
[ "#include <stdio.h>"; "#include <stdlib.h>"; "#include <string.h>"; "#include <math.h>" ]

let includes = [ "<stdio.h>"; "<stdlib.h>"; "<string.h>"; "<math.h>" ]
let typ_of_prec = Ops.c_typ_of_prec
let binop_syntax = Ops.binop_c_syntax
let unop_syntax = Ops.unop_c_syntax
let convert_precision = Ops.c_convert_precision
end

let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optimized) =
let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized) =
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
let procs = [| (lowered, opt_ctx_arrays) |]
let procs = [| lowered |]
end)) in
(* FIXME: do we really want all of them, or only the used ones? *)
let idx_params = Indexing.bound_symbols bindings in
let pp_file = Utils.pp_file ~base_name:name ~extension:".c" in
let is_global = Syntax.compile_globals pp_file.ppf in
let params = Syntax.compile_proc ~name pp_file.ppf idx_params ~is_global lowered in
Syntax.print_includes pp_file.ppf;
let params = Syntax.compile_proc ~name pp_file.ppf idx_params lowered in
pp_file.finalize ();
let result = c_compile_and_load ~f_name:pp_file.f_name in
{ result; params; bindings; name }

let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
(lowereds : Low_level.optimized option array) =
let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized option array) =
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
let procs =
Array.filter_mapi lowereds ~f:(fun i ->
Option.map ~f:(fun lowereds ->
(lowereds, Option.(map opt_ctx_arrays ~f:(fun ctx_arrays -> value_exn ctx_arrays.(i))))))
let procs = Array.filter_opt lowereds
end)) in
(* FIXME: do we really want all of them, or only the used ones? *)
let idx_params = Indexing.bound_symbols bindings in
Expand All @@ -122,11 +112,11 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
@@ common_prefix (Array.to_list @@ Array.concat_map ~f:Option.to_array names))
in
let pp_file = Utils.pp_file ~base_name ~extension:".c" in
let is_global = Syntax.compile_globals pp_file.ppf in
Syntax.print_includes pp_file.ppf;
let params =
Array.mapi lowereds ~f:(fun i lowered ->
Option.map2 names.(i) lowered ~f:(fun name lowered ->
Syntax.compile_proc ~name pp_file.ppf idx_params ~is_global lowered))
Syntax.compile_proc ~name pp_file.ppf idx_params lowered))
in
pp_file.finalize ();
let result = c_compile_and_load ~f_name:pp_file.f_name in
Expand Down
17 changes: 8 additions & 9 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -267,21 +267,20 @@ let%diagn2_sexp cuda_to_ptx ~name cu_src =
ptx

module C_syntax_config (Input : sig
val procs : (Low_level.optimized * ctx_arrays option) array
val procs : Low_level.optimized array
end) =
struct
type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of]

let procs = Input.procs
let hardcoded_context_ptr = None
let use_host_memory = use_host_memory
let logs_to_stdout = true
let main_kernel_prefix = "extern \"C\" __global__"

let kernel_prep_line =
"/* FIXME: single-threaded for now. */if (threadIdx.x != 0 || blockIdx.x != 0) { return; }"

let include_lines = [ "#include <cuda_fp16.h>" ]
let includes = [ "<cuda_fp16.h>" ]

let typ_of_prec = function
| Ops.Byte_prec _ -> "unsigned char"
Expand Down Expand Up @@ -341,31 +340,31 @@ let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
(* TODO: The following link seems to claim it's better to expand into loops than use memset.
https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
let procs = [| (lowered, None) |]
let procs = [| lowered |]
end)) in
let idx_params = Indexing.bound_symbols bindings in
let b = Buffer.create 4096 in
let ppf = Stdlib.Format.formatter_of_buffer b in
if Utils.debug_log_from_routines () then
Stdlib.Format.fprintf ppf "@,__device__ int printf (const char * format, ... );@,";
let is_global = Syntax.compile_globals ppf in
let params = Syntax.compile_proc ~name ~is_global ppf idx_params lowered in
Syntax.print_includes ppf;
let params = Syntax.compile_proc ~name ppf idx_params lowered in
let ptx = cuda_to_ptx ~name @@ Buffer.contents b in
{ traced_store; ptx; params; bindings; name }

let compile_batch ~names bindings lowereds =
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
let procs = Array.filter_map lowereds ~f:(Option.map ~f:(fun lowereds -> (lowereds, None)))
let procs = Array.filter_opt lowereds
end)) in
let idx_params = Indexing.bound_symbols bindings in
let b = Buffer.create 4096 in
let ppf = Stdlib.Format.formatter_of_buffer b in
let is_global = Syntax.compile_globals ppf in
Syntax.print_includes ppf;
let params_and_names =
Array.map2_exn names lowereds
~f:
(Option.map2 ~f:(fun name lowered ->
(Syntax.compile_proc ~name ~is_global ppf idx_params lowered, name)))
(Syntax.compile_proc ~name ppf idx_params lowered, name)))
in
let name : string =
String.(
Expand Down
27 changes: 12 additions & 15 deletions arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ let zero_out ctx block node =

let get_c_ptr ctx num_typ ptr = Gccjit.(RValue.ptr ctx (Type.pointer num_typ) ptr)

let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays ~param_ptrs
initializations (tn : Tn.t) =
let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~param_ptrs initializations
(tn : Tn.t) =
let open Gccjit in
let traced = Low_level.(get_node traced_store tn) in
let dims = Lazy.force tn.dims in
Expand All @@ -116,14 +116,12 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays
let hosted = Tn.is_hosted_force tn 344 in
let in_ctx = Tn.is_in_context ~use_host_memory tn in
let ptr =
match (in_ctx, opt_ctx_arrays, hosted) with
| Some true, Some ctx_arrays, _ ->
Lazy.from_val @@ get_c_ptr ctx num_typ @@ Map.find_exn ctx_arrays tn
| (Some true | None), None, _ ->
match (in_ctx, hosted) with
| Some true, _ ->
let p = Param.create ctx ptr_typ ident in
param_ptrs := (p, Param_ptr tn) :: !param_ptrs;
Lazy.from_val (RValue.param p)
| (Some false | None), _, true -> (
| (Some false | None), true -> (
let addr arr =
Lazy.from_val @@ get_c_ptr ctx num_typ @@ Ctypes.bigarray_start Ctypes_static.Genarray arr
in
Expand All @@ -133,7 +131,7 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays
| Some (Single_nd arr) -> addr arr
| Some (Double_nd arr) -> addr arr
| None -> assert false)
| (Some false | None), _, false ->
| (Some false | None), false ->
let arr_typ = Type.array ctx num_typ size_in_elems in
let v = ref None in
let initialize _init_block func = v := Some (Function.local func arr_typ ident) in
Expand Down Expand Up @@ -500,7 +498,7 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
loop_proc ~toplevel:true ~name ~env body;
!current_block

let%diagn_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident
let%diagn_sexp compile_proc ~name ctx bindings ~get_ident
Low_level.{ traced_store; llc = proc; merge_node } =
let open Gccjit in
let c_index = Type.get ctx Type.Int in
Expand Down Expand Up @@ -536,7 +534,7 @@ let%diagn_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident
let data =
prepare_node
~debug_log_zero_out:(debug_log_zero_out ctx log_functions get_ident)
~get_ident ctx traced_store ~opt_ctx_arrays ~param_ptrs initializations tn
~get_ident ctx traced_store ~param_ptrs initializations tn
in
Hashtbl.add_exn nodes ~key:tn ~data);
let params : (gccjit_param * param_source) list = !param_ptrs in
Expand Down Expand Up @@ -590,7 +588,7 @@ let%diagn_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident
Block.return_void after_proc;
(ctx_info, params)

let compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optimized) =
let compile ~(name : string) bindings (lowered : Low_level.optimized) =
let get_ident = Low_level.get_ident_within_code ~no_dots:true [| lowered.llc |] in
let open Gccjit in
if Option.is_none !root_ctx then initialize ();
Expand All @@ -599,7 +597,7 @@ let compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optim
(* if Utils.settings.with_debug && Utils.settings.output_debug_files_in_build_directory then (
Context.set_option ctx Context.Keep_intermediates true; Context.set_option ctx
Context.Dump_everything true); *)
let info, params = compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident lowered in
let info, params = compile_proc ~name ctx bindings ~get_ident lowered in
(if Utils.settings.output_debug_files_in_build_directory then
let f_name = Utils.build_file @@ name ^ "-gccjit-debug.c" in
Context.dump_to_file ctx ~update_locs:true f_name);
Expand All @@ -608,7 +606,7 @@ let compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optim
Context.release ctx;
{ info; result; bindings; name; params = List.map ~f:snd params }

let%diagn_sexp compile_batch ~(names : string option array) ~opt_ctx_arrays bindings
let%diagn_sexp compile_batch ~(names : string option array) bindings
(lowereds : Low_level.optimized option array) =
let get_ident =
Low_level.get_ident_within_code ~no_dots:true
Expand All @@ -623,10 +621,9 @@ let%diagn_sexp compile_batch ~(names : string option array) ~opt_ctx_arrays bind
Context.Dump_everything true); *)
let funcs =
Array.mapi lowereds ~f:(fun i lowered ->
let opt_ctx_arrays = Option.(join @@ map opt_ctx_arrays ~f:(fun arrs -> arrs.(i))) in
match (names.(i), lowered) with
| Some name, Some lowered ->
let info, params = compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident lowered in
let info, params = compile_proc ~name ctx bindings ~get_ident lowered in
Some (info, params)
| _ -> None)
in
Expand Down
Loading

0 comments on commit 606f3d2

Please sign in to comment.