diff --git a/CHANGES.md b/CHANGES.md index 55c5c94e..51fa2963 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -22,6 +22,7 @@ - Huge refactoring of backend internal interfaces and API (not repeating same code). - Built per-tensor-node stream-to-stream synchronization into copying functions. - Re-introduced whole-device blocking synchronization, which now is just a slight optimization as it also cleans up event book-keeping. +- Simplifications: no more explicit compilation postponing; no more hard-coded pointers (all non-local arrays are passed by parameter). ### Fixed diff --git a/arrayjit/lib/backend_impl.ml b/arrayjit/lib/backend_impl.ml index 5d27443d..8574967e 100644 --- a/arrayjit/lib/backend_impl.ml +++ b/arrayjit/lib/backend_impl.ml @@ -158,23 +158,13 @@ module type Lowered_no_device_backend = sig type procedure [@@deriving sexp_of] - val compile : - name:string -> - opt_ctx_arrays:ctx_arrays option -> - Indexing.unit_bindings -> - Low_level.optimized -> - procedure - (** [opt_ctx_arrays], if any, already contain the arrays of the context that will result from - linking the code. *) + val compile : name:string -> Indexing.unit_bindings -> Low_level.optimized -> procedure val compile_batch : names:string option array -> - opt_ctx_arrays:ctx_arrays option array option -> Indexing.unit_bindings -> Low_level.optimized option array -> procedure option array - (** [opt_ctx_arrays], if any, already contain the arrays of the contexts that will result from - linking the code. *) val link_compiled : merge_buffer:buffer option ref -> diff --git a/arrayjit/lib/backends.ml b/arrayjit/lib/backends.ml index 194d1547..7085d8b4 100644 --- a/arrayjit/lib/backends.ml +++ b/arrayjit/lib/backends.ml @@ -218,11 +218,11 @@ module Add_device [@@deriving sexp_of] let compile ~name bindings lowered : code = - let proc = compile ~name ~opt_ctx_arrays:None bindings lowered in + let proc = compile ~name bindings lowered in { lowered; proc } let compile_batch ~names bindings lowereds : code_batch = - let procs = compile_batch ~names ~opt_ctx_arrays:None bindings lowereds in + let procs = compile_batch ~names bindings lowereds in { lowereds; procs } include Add_scheduler (Backend) diff --git a/arrayjit/lib/c_syntax.ml b/arrayjit/lib/c_syntax.ml index 6c0dc928..a7a418f8 100644 --- a/arrayjit/lib/c_syntax.ml +++ b/arrayjit/lib/c_syntax.ml @@ -11,18 +11,15 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime module Tn = Tnode module C_syntax (B : sig - type buffer_ptr - - val procs : (Low_level.optimized * buffer_ptr ctx_arrays option) array + val procs : Low_level.optimized array (** The low-level prcedure to compile, and the arrays of the context it will be linked to if not shared and already known. *) - val hardcoded_context_ptr : (buffer_ptr -> Ops.prec -> string) option val use_host_memory : bool val logs_to_stdout : bool val main_kernel_prefix : string val kernel_prep_line : string - val include_lines : string list + val includes : string list val typ_of_prec : Ops.prec -> string val binop_syntax : Ops.prec -> Ops.binop -> string * string * string val unop_syntax : Ops.prec -> Ops.unop -> string * string @@ -30,13 +27,15 @@ module C_syntax (B : sig end) = struct let get_ident = - Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun (l, _) -> l.llc) + Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun l -> l.llc) let in_ctx tn = B.(Tn.is_in_context ~use_host_memory tn) let pp_zero_out ppf tn = Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn + let pp_include ppf s = Stdlib.Format.fprintf ppf "#include %s" s + open Indexing.Pp_helpers let pp_array_offset ppf (idcs, dims) = @@ -61,33 +60,8 @@ struct (* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim -> idx + (offset * dim)) *) - let%debug3_sexp compile_globals ppf : Tn.t Hash_set.t = - let open Stdlib.Format in - let is_global = Hash_set.create (module Tn) in - fprintf ppf {|@[%a@,/* Global declarations. */@,|} (pp_print_list pp_print_string) - B.include_lines; - Array.iter B.procs ~f:(fun (l, ctx_arrays) -> - Hashtbl.iter l.Low_level.traced_store ~f:(fun (node : Low_level.traced_array) -> - let tn = node.tn in - if not @@ Hash_set.mem is_global tn then - let ctx_ptr = B.hardcoded_context_ptr in - let mem : (Tn.memory_mode * int) option = tn.memory_mode in - match (in_ctx tn, ctx_ptr, ctx_arrays, mem) with - | Some true, Some get_ptr, Some ctx_arrays, _ -> - let ident = get_ident tn in - let ctx_array = - Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays tn - in - fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array (Lazy.force tn.prec); - Hash_set.add is_global tn - | Some false, _, _, Some (Hosted _, _) - when B.(Tn.known_shared_with_host ~use_host_memory tn) -> - let nd = Option.value_exn ~here:[%here] @@ Lazy.force tn.array in - fprintf ppf "#define %s (%s)@," (get_ident tn) (Ndarray.c_ptr_to_string nd); - Hash_set.add is_global tn - | _ -> ())); - fprintf ppf "@,@]"; - is_global + let print_includes ppf = + Stdlib.Format.(fprintf ppf {|@[%a@,|} (pp_print_list pp_include) B.includes) let compile_main ~traced_store ppf llc : unit = let open Stdlib.Format in @@ -285,18 +259,16 @@ struct in pp_ll ppf llc - let%track3_sexp compile_proc ~name ppf idx_params ~is_global - Low_level.{ traced_store; llc; merge_node } = + let%track3_sexp compile_proc ~name ppf idx_params Low_level.{ traced_store; llc; merge_node } = let open Stdlib.Format in let params : (string * param_source) list = - (* Preserve the order in the hashtable, so it's the same as e.g. in compile_globals. *) + (* Preserve the order in the hashtable. *) List.rev @@ Hashtbl.fold traced_store ~init:[] ~f:(fun ~key:tn ~data:_ params -> (* A rough approximation to the type Gccjit_backend.mem_properties. *) let backend_info = Sexp.Atom - (if Hash_set.mem is_global tn then "Host" - else if Tn.is_virtual_force tn 334 then "Virt" + (if Tn.is_virtual_force tn 334 then "Virt" else match in_ctx tn with | Some true -> "Ctx" @@ -307,7 +279,7 @@ struct tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info; (* We often don't know ahead of linking with relevant contexts what the stream sharing mode of the node will become. Conservatively, use passing as argument. *) - if Option.value ~default:true (in_ctx tn) && not (Hash_set.mem is_global tn) then + if Option.value ~default:true (in_ctx tn) then (B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params else params) in @@ -373,12 +345,7 @@ struct params); fprintf ppf "/* Local declarations and initialization. */@ "; Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node -> - if - not - (Tn.is_virtual_force tn 333 - || Option.value ~default:true (in_ctx tn) - || Hash_set.mem is_global tn) - then + if not (Tn.is_virtual_force tn 333 || Option.value ~default:true (in_ctx tn)) then fprintf ppf "%s %s[%d]%s;@ " (B.typ_of_prec @@ Lazy.force tn.prec) (get_ident tn) (Tn.num_elems tn) diff --git a/arrayjit/lib/cc_backend.ml b/arrayjit/lib/cc_backend.ml index e42aa95e..0d5cde2a 100644 --- a/arrayjit/lib/cc_backend.ml +++ b/arrayjit/lib/cc_backend.ml @@ -72,47 +72,37 @@ let c_compile_and_load ~f_name = result module C_syntax_config (Input : sig - val procs : (Low_level.optimized * buffer_ptr ctx_arrays option) array + val procs : Low_level.optimized array end) = struct - type nonrec buffer_ptr = buffer_ptr - let procs = Input.procs - let hardcoded_context_ptr = c_ptr_to_string let use_host_memory = use_host_memory let logs_to_stdout = false let main_kernel_prefix = "" let kernel_prep_line = "" - - let include_lines = - [ "#include "; "#include "; "#include "; "#include " ] - + let includes = [ ""; ""; ""; "" ] let typ_of_prec = Ops.c_typ_of_prec let binop_syntax = Ops.binop_c_syntax let unop_syntax = Ops.unop_c_syntax let convert_precision = Ops.c_convert_precision end -let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optimized) = +let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized) = let module Syntax = C_syntax.C_syntax (C_syntax_config (struct - let procs = [| (lowered, opt_ctx_arrays) |] + let procs = [| lowered |] end)) in (* FIXME: do we really want all of them, or only the used ones? *) let idx_params = Indexing.bound_symbols bindings in let pp_file = Utils.pp_file ~base_name:name ~extension:".c" in - let is_global = Syntax.compile_globals pp_file.ppf in - let params = Syntax.compile_proc ~name pp_file.ppf idx_params ~is_global lowered in + Syntax.print_includes pp_file.ppf; + let params = Syntax.compile_proc ~name pp_file.ppf idx_params lowered in pp_file.finalize (); let result = c_compile_and_load ~f_name:pp_file.f_name in { result; params; bindings; name } -let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings - (lowereds : Low_level.optimized option array) = +let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized option array) = let module Syntax = C_syntax.C_syntax (C_syntax_config (struct - let procs = - Array.filter_mapi lowereds ~f:(fun i -> - Option.map ~f:(fun lowereds -> - (lowereds, Option.(map opt_ctx_arrays ~f:(fun ctx_arrays -> value_exn ctx_arrays.(i)))))) + let procs = Array.filter_opt lowereds end)) in (* FIXME: do we really want all of them, or only the used ones? *) let idx_params = Indexing.bound_symbols bindings in @@ -122,11 +112,11 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings @@ common_prefix (Array.to_list @@ Array.concat_map ~f:Option.to_array names)) in let pp_file = Utils.pp_file ~base_name ~extension:".c" in - let is_global = Syntax.compile_globals pp_file.ppf in + Syntax.print_includes pp_file.ppf; let params = Array.mapi lowereds ~f:(fun i lowered -> Option.map2 names.(i) lowered ~f:(fun name lowered -> - Syntax.compile_proc ~name pp_file.ppf idx_params ~is_global lowered)) + Syntax.compile_proc ~name pp_file.ppf idx_params lowered)) in pp_file.finalize (); let result = c_compile_and_load ~f_name:pp_file.f_name in diff --git a/arrayjit/lib/cuda_backend.cudajit.ml b/arrayjit/lib/cuda_backend.cudajit.ml index d47e1379..8dd80a02 100644 --- a/arrayjit/lib/cuda_backend.cudajit.ml +++ b/arrayjit/lib/cuda_backend.cudajit.ml @@ -267,13 +267,12 @@ let%diagn2_sexp cuda_to_ptx ~name cu_src = ptx module C_syntax_config (Input : sig - val procs : (Low_level.optimized * ctx_arrays option) array + val procs : Low_level.optimized array end) = struct type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of] let procs = Input.procs - let hardcoded_context_ptr = None let use_host_memory = use_host_memory let logs_to_stdout = true let main_kernel_prefix = "extern \"C\" __global__" @@ -281,7 +280,7 @@ struct let kernel_prep_line = "/* FIXME: single-threaded for now. */if (threadIdx.x != 0 || blockIdx.x != 0) { return; }" - let include_lines = [ "#include " ] + let includes = [ "" ] let typ_of_prec = function | Ops.Byte_prec _ -> "unsigned char" @@ -341,31 +340,31 @@ let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) = (* TODO: The following link seems to claim it's better to expand into loops than use memset. https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *) let module Syntax = C_syntax.C_syntax (C_syntax_config (struct - let procs = [| (lowered, None) |] + let procs = [| lowered |] end)) in let idx_params = Indexing.bound_symbols bindings in let b = Buffer.create 4096 in let ppf = Stdlib.Format.formatter_of_buffer b in if Utils.debug_log_from_routines () then Stdlib.Format.fprintf ppf "@,__device__ int printf (const char * format, ... );@,"; - let is_global = Syntax.compile_globals ppf in - let params = Syntax.compile_proc ~name ~is_global ppf idx_params lowered in + Syntax.print_includes ppf; + let params = Syntax.compile_proc ~name ppf idx_params lowered in let ptx = cuda_to_ptx ~name @@ Buffer.contents b in { traced_store; ptx; params; bindings; name } let compile_batch ~names bindings lowereds = let module Syntax = C_syntax.C_syntax (C_syntax_config (struct - let procs = Array.filter_map lowereds ~f:(Option.map ~f:(fun lowereds -> (lowereds, None))) + let procs = Array.filter_opt lowereds end)) in let idx_params = Indexing.bound_symbols bindings in let b = Buffer.create 4096 in let ppf = Stdlib.Format.formatter_of_buffer b in - let is_global = Syntax.compile_globals ppf in + Syntax.print_includes ppf; let params_and_names = Array.map2_exn names lowereds ~f: (Option.map2 ~f:(fun name lowered -> - (Syntax.compile_proc ~name ~is_global ppf idx_params lowered, name))) + (Syntax.compile_proc ~name ppf idx_params lowered, name))) in let name : string = String.( diff --git a/arrayjit/lib/gcc_backend.gccjit.ml b/arrayjit/lib/gcc_backend.gccjit.ml index 57ffbf0f..3302e75c 100644 --- a/arrayjit/lib/gcc_backend.gccjit.ml +++ b/arrayjit/lib/gcc_backend.gccjit.ml @@ -98,8 +98,8 @@ let zero_out ctx block node = let get_c_ptr ctx num_typ ptr = Gccjit.(RValue.ptr ctx (Type.pointer num_typ) ptr) -let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays ~param_ptrs - initializations (tn : Tn.t) = +let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~param_ptrs initializations + (tn : Tn.t) = let open Gccjit in let traced = Low_level.(get_node traced_store tn) in let dims = Lazy.force tn.dims in @@ -116,14 +116,12 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays let hosted = Tn.is_hosted_force tn 344 in let in_ctx = Tn.is_in_context ~use_host_memory tn in let ptr = - match (in_ctx, opt_ctx_arrays, hosted) with - | Some true, Some ctx_arrays, _ -> - Lazy.from_val @@ get_c_ptr ctx num_typ @@ Map.find_exn ctx_arrays tn - | (Some true | None), None, _ -> + match (in_ctx, hosted) with + | Some true, _ -> let p = Param.create ctx ptr_typ ident in param_ptrs := (p, Param_ptr tn) :: !param_ptrs; Lazy.from_val (RValue.param p) - | (Some false | None), _, true -> ( + | (Some false | None), true -> ( let addr arr = Lazy.from_val @@ get_c_ptr ctx num_typ @@ Ctypes.bigarray_start Ctypes_static.Genarray arr in @@ -133,7 +131,7 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays | Some (Single_nd arr) -> addr arr | Some (Double_nd arr) -> addr arr | None -> assert false) - | (Some false | None), _, false -> + | (Some false | None), false -> let arr_typ = Type.array ctx num_typ size_in_elems in let v = ref None in let initialize _init_block func = v := Some (Function.local func arr_typ ident) in @@ -500,7 +498,7 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; loop_proc ~toplevel:true ~name ~env body; !current_block -let%diagn_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident +let%diagn_sexp compile_proc ~name ctx bindings ~get_ident Low_level.{ traced_store; llc = proc; merge_node } = let open Gccjit in let c_index = Type.get ctx Type.Int in @@ -536,7 +534,7 @@ let%diagn_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident let data = prepare_node ~debug_log_zero_out:(debug_log_zero_out ctx log_functions get_ident) - ~get_ident ctx traced_store ~opt_ctx_arrays ~param_ptrs initializations tn + ~get_ident ctx traced_store ~param_ptrs initializations tn in Hashtbl.add_exn nodes ~key:tn ~data); let params : (gccjit_param * param_source) list = !param_ptrs in @@ -590,7 +588,7 @@ let%diagn_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident Block.return_void after_proc; (ctx_info, params) -let compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optimized) = +let compile ~(name : string) bindings (lowered : Low_level.optimized) = let get_ident = Low_level.get_ident_within_code ~no_dots:true [| lowered.llc |] in let open Gccjit in if Option.is_none !root_ctx then initialize (); @@ -599,7 +597,7 @@ let compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optim (* if Utils.settings.with_debug && Utils.settings.output_debug_files_in_build_directory then ( Context.set_option ctx Context.Keep_intermediates true; Context.set_option ctx Context.Dump_everything true); *) - let info, params = compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident lowered in + let info, params = compile_proc ~name ctx bindings ~get_ident lowered in (if Utils.settings.output_debug_files_in_build_directory then let f_name = Utils.build_file @@ name ^ "-gccjit-debug.c" in Context.dump_to_file ctx ~update_locs:true f_name); @@ -608,7 +606,7 @@ let compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optim Context.release ctx; { info; result; bindings; name; params = List.map ~f:snd params } -let%diagn_sexp compile_batch ~(names : string option array) ~opt_ctx_arrays bindings +let%diagn_sexp compile_batch ~(names : string option array) bindings (lowereds : Low_level.optimized option array) = let get_ident = Low_level.get_ident_within_code ~no_dots:true @@ -623,10 +621,9 @@ let%diagn_sexp compile_batch ~(names : string option array) ~opt_ctx_arrays bind Context.Dump_everything true); *) let funcs = Array.mapi lowereds ~f:(fun i lowered -> - let opt_ctx_arrays = Option.(join @@ map opt_ctx_arrays ~f:(fun arrs -> arrs.(i))) in match (names.(i), lowered) with | Some name, Some lowered -> - let info, params = compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident lowered in + let info, params = compile_proc ~name ctx bindings ~get_ident lowered in Some (info, params) | _ -> None) in diff --git a/arrayjit/lib/no_device_backend_missing.ml b/arrayjit/lib/no_device_backend_missing.ml index 172b0c3b..f5f6f7b9 100644 --- a/arrayjit/lib/no_device_backend_missing.ml +++ b/arrayjit/lib/no_device_backend_missing.ml @@ -9,10 +9,10 @@ type procedure let sexp_of_procedure _procedure = failwith "Backend missing -- install the corresponding library" -let compile ~name:_ ~opt_ctx_arrays:_ _unit_bindings _optimized = +let compile ~name:_ _unit_bindings _optimized = failwith "Backend missing -- install the corresponding library" -let compile_batch ~names:_ ~opt_ctx_arrays:_ _unit_bindings _optimizeds = +let compile_batch ~names:_ _unit_bindings _optimizeds = failwith "Backend missing -- install the corresponding library" let link_compiled ~merge_buffer:_ ~runner_label:_ _ctx_arrays _procedure =