
Commit 25c71e5
Get rid of hard-coded pointers: all materialized nodes are kernel parameters
lukstafi committed Dec 9, 2024
1 parent 606f3d2 commit 25c71e5
Showing 4 changed files with 18 additions and 17 deletions.
4 changes: 2 additions & 2 deletions arrayjit/lib/anatomy_of_a_backend.md
@@ -4,7 +4,7 @@

- [The Anatomy of an OCANNL Backend](#the-anatomy-of-an-ocannl-backend)
- [Design around compiling and running code, backend interfaces](#design-around-compiling-and-running-code-backend-interfaces)
- - [Shared relocatable compilation, batch compilation](#shared-relocatable-compilation-batch-compilation)
+ - [Batch compilation; in the future: lazy and cached compilation artifacts](#batch-compilation-in-the-future-lazy-and-cached-compilation-artifacts)
- [Tensor nodes, arrays, memory properties](#tensor-nodes-arrays-memory-properties)
- [Typical details of a backend implementation](#typical-details-of-a-backend-implementation)
- [Conditionally emitting the tracing debugger code](#conditionally-emitting-the-tracing-debugger-code)
@@ -125,7 +125,7 @@ Conventionally, the compilation implementation is split into three functions / l
- On GPU-like backends, we cannot load the code at compile time. For example, the CUDA driver API function `cuModuleLoadDataEx` loads the module into _the current context_, which is device-specific, so it must be called from within `link` or `link_batch`.
- GPU-like backends necessitate distinguishing between `link` and `link_batch`, to prevent the same code from being loaded as multiple modules.

-The `C_syntax` functor returns the `compile_proc` function for use by `compile` and `compile_batch` of the backends.
+The `C_syntax` functor returns the `compile_proc` function for use by `compile` and `compile_batch` of the backends. For simplicity, `C_syntax` passes all materialized nodes by parameters even for backends that use some nodes directly from the host rather than from the device / from context.

### Conditionally emitting the tracing debugger code

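For orientation only (not part of the commit): a minimal, standalone sketch of the behavior the added documentation sentence describes — every materialized node becomes a pointer parameter of the emitted routine instead of a pointer hard-coded into the generated source. The routine name, node names (`n1`, `n2`, `tmp`), the `float` element type, and the `kernel_signature` helper are all hypothetical.

```ocaml
(* Standalone sketch, not repository code: build the C-style parameter list that a
   C_syntax-like backend would emit when every materialized node is passed by parameter. *)
let kernel_signature ~name nodes =
  let params =
    List.filter_map
      (fun (ident, materialized) ->
        if materialized then Some ("float *" ^ ident) else None)
      nodes
  in
  Printf.sprintf "void %s(%s)" name (String.concat ", " params)

let () =
  (* Suppose "n1" lives in a device context and "n2" is hosted: both are materialized,
     so both become parameters; the virtual/local "tmp" does not. *)
  print_endline
    (kernel_signature ~name:"forward" [ ("n1", true); ("n2", true); ("tmp", false) ])
(* Prints: void forward(float *n1, float *n2) *)
```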
21 changes: 10 additions & 11 deletions arrayjit/lib/c_syntax.ml
@@ -61,7 +61,7 @@ struct
(* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim
-> idx + (offset * dim)) *)
let print_includes ppf =
-Stdlib.Format.(fprintf ppf {|@[<v 0>%a@,|} (pp_print_list pp_include) B.includes)
+Stdlib.Format.(fprintf ppf {|@[<v 0>%a@,@,|} (pp_print_list pp_include) B.includes)
let compile_main ~traced_store ppf llc : unit =
let open Stdlib.Format in
@@ -266,20 +266,19 @@ struct
List.rev
@@ Hashtbl.fold traced_store ~init:[] ~f:(fun ~key:tn ~data:_ params ->
(* A rough approximation to the type Gccjit_backend.mem_properties. *)
-let backend_info =
-  Sexp.Atom
-    (if Tn.is_virtual_force tn 334 then "Virt"
-     else
-       match in_ctx tn with
-       | Some true -> "Ctx"
-       | Some false -> "Local"
-       | None -> "Unk")
+let backend_info, is_param =
+  if Tn.is_virtual_force tn 334 then ("Virt", false)
+  else if Option.value ~default:false @@ in_ctx tn then ("Ctx", true)
+  else if Tn.is_materialized_force tn 335 then ("Global or ctx", true)
+  else if Tn.known_not_materialized tn then ("Local", false)
+  else assert false
in
+let backend_info = Sexp.Atom backend_info in
if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then
tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
(* We often don't know ahead of linking with relevant contexts what the stream sharing
mode of the node will become. Conservatively, use passing as argument. *)
-if Option.value ~default:true (in_ctx tn) then
+if is_param then
(B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params
else params)
in
@@ -345,7 +344,7 @@ struct
params);
fprintf ppf "/* Local declarations and initialization. */@ ";
Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
-if not (Tn.is_virtual_force tn 333 || Option.value ~default:true (in_ctx tn)) then
+if not (Tn.is_virtual_force tn 333 || Tn.is_materialized_force tn 336) then
fprintf ppf "%s %s[%d]%s;@ "
(B.typ_of_prec @@ Lazy.force tn.prec)
(get_ident tn) (Tn.num_elems tn)
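An aside, not part of the commit: the new classification above reads as a small decision table — virtual and known-local nodes stay out of the parameter list (locals get an in-routine array declaration in the second hunk instead), while anything materialized in a context or globally is passed as a kernel parameter. Below is a self-contained sketch of that mapping; the `node_kind` variant and `classify` function are hypothetical stand-ins for the `Tnode` queries used in the diff.

```ocaml
(* Standalone sketch; the variant only mirrors the branches of the new code:
   (backend_info label, is the node a kernel parameter?). *)
type node_kind = Virtual | In_context | Materialized_globally_or_ctx | Known_local

let classify = function
  | Virtual -> ("Virt", false)
  | In_context -> ("Ctx", true)
  | Materialized_globally_or_ctx -> ("Global or ctx", true)
  | Known_local -> ("Local", false)

let () =
  List.iter
    (fun kind ->
      let info, is_param = classify kind in
      Printf.printf "%-13s -> parameter: %b\n" info is_param)
    [ Virtual; In_context; Materialized_globally_or_ctx; Known_local ]
```

Per the commit title and the documentation change above, the net effect is that materialized nodes are always reached through routine parameters rather than through pointers baked into the generated code.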
6 changes: 4 additions & 2 deletions arrayjit/lib/tnode.ml
@@ -478,21 +478,23 @@ let get_style ?(arg_name = "ll_ident_style") ?(no_dots = false) () =
invalid_arg @@ "Wrong " ^ arg_name ^ ", must be one of: heuristic, name_and_label, name_only"

let header tn =
+let debug = Utils.settings.log_level > 0 in
let mem_size =
if Lazy.is_val tn.array then
match tn.array with
| (lazy None) -> "<not-hosted>"
| (lazy (Some nd)) ->
let size = Int.to_string_hum @@ Nd.size_in_bytes nd in
-if Utils.settings.log_level > 0 then size ^ " @ " ^ Nd.ptr_to_string_hum nd else size
+if debug then size ^ " @ " ^ Nd.ptr_to_string_hum nd else size
else "<not-in-yet>"
in
let repeating_nograd_idents = Hashtbl.create ~size:1 (module String) in
let repeating_grad_idents = Hashtbl.create ~size:1 (module String) in
[%string
{|%{id tn} %{label tn} as %{
styled_ident ~repeating_nograd_idents ~repeating_grad_idents (`Heuristic_ocannl `Dot_grad) tn
-}: %{debug_memory_mode tn.memory_mode}; %{dims_to_string tn}; mem in bytes: %{mem_size}|}]
+}: %{debug_memory_mode tn.memory_mode}; %{dims_to_string tn}; mem in bytes: %{mem_size}%{
+if debug then "; debug: " ^ Sexp.to_string_hum tn.backend_info else ""}|}]

module Registry = Stdlib.Weak.Make (struct
type nonrec t = t
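A side note (not part of the commit): the tnode.ml change hoists the repeated `log_level > 0` check into a `debug` flag and, when it is set, appends the node's accumulated `backend_info` sexp to the header line, making the labels recorded in c_syntax.ml ("Virt", "Ctx", "Global or ctx", "Local") visible. A tiny sketch of the suffix logic; the header text and values below are made up.

```ocaml
(* Standalone sketch with hypothetical values: append a debug suffix only when enabled. *)
let header_suffix ~debug backend_info =
  if debug then "; debug: " ^ backend_info else ""

let () =
  print_endline
    ("n42 w_hid as w_hid: Hosted; 16x16; mem in bytes: 1_024"
    ^ header_suffix ~debug:true "(Ctx)")
(* Prints: n42 w_hid as w_hid: Hosted; 16x16; mem in bytes: 1_024; debug: (Ctx) *)
```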
4 changes: 2 additions & 2 deletions bin/micrograd_demo.ml
@@ -12,7 +12,7 @@ module Debug_runtime = Utils.Debug_runtime

let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
Rand.init 0;
-Utils.enable_runtime_debug ();
+(* Utils.enable_runtime_debug (); *)
(* Utils.settings.debug_log_from_routines <- true; *)
let hid_dim = 16 in
let len = 300 in
@@ -183,7 +183,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
PrintBox_text.output Stdio.stdout plot_lr

let () = experiment 4 ~no_batch_shape_inference:true ~use_builtin_weight_decay:true ()
-let () = experiment 4 ~no_batch_shape_inference:false ~use_builtin_weight_decay:false ()
+let _suspended () = experiment 4 ~no_batch_shape_inference:false ~use_builtin_weight_decay:false ()

let _suspended () =
for seed = 0 to 19 do
