From 25c71e532a87983175e98a4dcc0752a18560095e Mon Sep 17 00:00:00 2001 From: Lukasz Stafiniak Date: Mon, 9 Dec 2024 10:04:22 +0100 Subject: [PATCH] Get rid of hard-coded pointers: all materialized nodes are kernel parameters --- arrayjit/lib/anatomy_of_a_backend.md | 4 ++-- arrayjit/lib/c_syntax.ml | 21 ++++++++++----------- arrayjit/lib/tnode.ml | 6 ++++-- bin/micrograd_demo.ml | 4 ++-- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/arrayjit/lib/anatomy_of_a_backend.md b/arrayjit/lib/anatomy_of_a_backend.md index 1a97c017..defa9016 100644 --- a/arrayjit/lib/anatomy_of_a_backend.md +++ b/arrayjit/lib/anatomy_of_a_backend.md @@ -4,7 +4,7 @@ - [The Anatomy of an OCANNL Backend](#the-anatomy-of-an-ocannl-backend) - [Design around compiling and running code, backend interfaces](#design-around-compiling-and-running-code-backend-interfaces) - - [Shared relocatable compilation, batch compilation](#shared-relocatable-compilation-batch-compilation) + - [Batch compilation; in the future: lazy and cached compilation artifacts](#batch-compilation-in-the-future-lazy-and-cached-compilation-artifacts) - [Tensor nodes, arrays, memory properties](#tensor-nodes-arrays-memory-properties) - [Typical details of a backend implementation](#typical-details-of-a-backend-implementation) - [Conditionally emitting the tracing debugger code](#conditionally-emitting-the-tracing-debugger-code) @@ -125,7 +125,7 @@ Conventionally, the compilation implementation is split into three functions / l - On GPU-like backends, we cannot load the code at compile time. For example, the CUDA driver API function `cuModuleLoadDataEx` loads the module into _the current context_, which is device-specific, so it must be called from within `link` or `link_batch`. - GPU-like backends necessitate distinguishing between `link` and `link_batch`, to prevent the same code from being loaded as multiple modules. -The `C_syntax` functor returns the `compile_proc` function for use by `compile` and `compile_batch` of the backends. +The `C_syntax` functor returns the `compile_proc` function for use by `compile` and `compile_batch` of the backends. For simplicity, `C_syntax` passes all materialized nodes by parameters even for backends that use some nodes directly from the host rather than from the device / from context. ### Conditionally emitting the tracing debugger code diff --git a/arrayjit/lib/c_syntax.ml b/arrayjit/lib/c_syntax.ml index a7a418f8..9bbc0ad1 100644 --- a/arrayjit/lib/c_syntax.ml +++ b/arrayjit/lib/c_syntax.ml @@ -61,7 +61,7 @@ struct (* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim -> idx + (offset * dim)) *) let print_includes ppf = - Stdlib.Format.(fprintf ppf {|@[%a@,|} (pp_print_list pp_include) B.includes) + Stdlib.Format.(fprintf ppf {|@[%a@,@,|} (pp_print_list pp_include) B.includes) let compile_main ~traced_store ppf llc : unit = let open Stdlib.Format in @@ -266,20 +266,19 @@ struct List.rev @@ Hashtbl.fold traced_store ~init:[] ~f:(fun ~key:tn ~data:_ params -> (* A rough approximation to the type Gccjit_backend.mem_properties. *) - let backend_info = - Sexp.Atom - (if Tn.is_virtual_force tn 334 then "Virt" - else - match in_ctx tn with - | Some true -> "Ctx" - | Some false -> "Local" - | None -> "Unk") + let backend_info, is_param = + if Tn.is_virtual_force tn 334 then ("Virt", false) + else if Option.value ~default:false @@ in_ctx tn then ("Ctx", true) + else if Tn.is_materialized_force tn 335 then ("Global or ctx", true) + else if Tn.known_not_materialized tn then ("Local", false) + else assert false in + let backend_info = Sexp.Atom backend_info in if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info; (* We often don't know ahead of linking with relevant contexts what the stream sharing mode of the node will become. Conservatively, use passing as argument. *) - if Option.value ~default:true (in_ctx tn) then + if is_param then (B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params else params) in @@ -345,7 +344,7 @@ struct params); fprintf ppf "/* Local declarations and initialization. */@ "; Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node -> - if not (Tn.is_virtual_force tn 333 || Option.value ~default:true (in_ctx tn)) then + if not (Tn.is_virtual_force tn 333 || Tn.is_materialized_force tn 336) then fprintf ppf "%s %s[%d]%s;@ " (B.typ_of_prec @@ Lazy.force tn.prec) (get_ident tn) (Tn.num_elems tn) diff --git a/arrayjit/lib/tnode.ml b/arrayjit/lib/tnode.ml index a26aead4..cc45d6ac 100644 --- a/arrayjit/lib/tnode.ml +++ b/arrayjit/lib/tnode.ml @@ -478,13 +478,14 @@ let get_style ?(arg_name = "ll_ident_style") ?(no_dots = false) () = invalid_arg @@ "Wrong " ^ arg_name ^ ", must be one of: heuristic, name_and_label, name_only" let header tn = + let debug = Utils.settings.log_level > 0 in let mem_size = if Lazy.is_val tn.array then match tn.array with | (lazy None) -> "" | (lazy (Some nd)) -> let size = Int.to_string_hum @@ Nd.size_in_bytes nd in - if Utils.settings.log_level > 0 then size ^ " @ " ^ Nd.ptr_to_string_hum nd else size + if debug then size ^ " @ " ^ Nd.ptr_to_string_hum nd else size else "" in let repeating_nograd_idents = Hashtbl.create ~size:1 (module String) in @@ -492,7 +493,8 @@ let header tn = [%string {|%{id tn} %{label tn} as %{ styled_ident ~repeating_nograd_idents ~repeating_grad_idents (`Heuristic_ocannl `Dot_grad) tn - }: %{debug_memory_mode tn.memory_mode}; %{dims_to_string tn}; mem in bytes: %{mem_size}|}] + }: %{debug_memory_mode tn.memory_mode}; %{dims_to_string tn}; mem in bytes: %{mem_size}%{ + if debug then "; debug: " ^ Sexp.to_string_hum tn.backend_info else ""}|}] module Registry = Stdlib.Weak.Make (struct type nonrec t = t diff --git a/bin/micrograd_demo.ml b/bin/micrograd_demo.ml index 6ef8c3aa..744b4f09 100644 --- a/bin/micrograd_demo.ml +++ b/bin/micrograd_demo.ml @@ -12,7 +12,7 @@ module Debug_runtime = Utils.Debug_runtime let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () = Rand.init 0; - Utils.enable_runtime_debug (); + (* Utils.enable_runtime_debug (); *) (* Utils.settings.debug_log_from_routines <- true; *) let hid_dim = 16 in let len = 300 in @@ -183,7 +183,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () = PrintBox_text.output Stdio.stdout plot_lr let () = experiment 4 ~no_batch_shape_inference:true ~use_builtin_weight_decay:true () -let () = experiment 4 ~no_batch_shape_inference:false ~use_builtin_weight_decay:false () +let _suspended () = experiment 4 ~no_batch_shape_inference:false ~use_builtin_weight_decay:false () let _suspended () = for seed = 0 to 19 do