Skip to content

Commit

Permalink
Memoize size_in_bytes inside Tnode.t
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Dec 16, 2024
1 parent 186a2d3 commit 77b3395
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 14 deletions.
3 changes: 1 addition & 2 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ module Add_device

let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
let s = dst.stream in
let size_in_bytes = Tnode.size_in_bytes tn in
let size_in_bytes = Lazy.force tn.Tnode.size_in_bytes in
let work =
(* TODO: log the operation if [Utils.settings.with_log_level > 1]. *)
match (into_merge_buffer, dst_ptr) with
Expand All @@ -280,7 +280,6 @@ module Add_device
| Streaming_for _, _ -> fun () -> s.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
| Copy, _ ->
fun () ->
let size_in_bytes = Tnode.size_in_bytes tn in
let allocated_capacity =
match s.allocated_buffer with None -> 0 | Some buf -> buf.size_in_bytes
in
Expand Down
3 changes: 2 additions & 1 deletion arrayjit/lib/c_syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ struct
(** [in_ctx tn] evaluates [Tn.is_in_context_force ~use_host_memory tn 46], with [Tn] and
    [use_host_memory] resolved inside module [B].
    NOTE(review): the literal [46] is presumably a call-site identifier used for diagnostics --
    confirm against the definition of [is_in_context_force]. *)
let in_ctx tn =
  let open B in
  Tn.is_in_context_force ~use_host_memory tn 46

let pp_zero_out ppf tn =
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn)
@@ Lazy.force tn.size_in_bytes

(** [pp_include ppf header] prints the text ["#include "] followed by [header] on the formatter
    [ppf]. [header] is emitted verbatim, so the caller supplies any surrounding quotes or angle
    brackets. No trailing newline or break hint is printed. *)
let pp_include ppf header = Stdlib.Format.fprintf ppf "#include %s" header

Expand Down
3 changes: 1 addition & 2 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
dst.stream.stream_id src.stream.stream_id; *)
let dev = dst.stream.device in
let same_device = dev.ordinal = src.stream.device.ordinal in
let size_in_bytes = Tn.size_in_bytes tn in
let size_in_bytes = Lazy.force tn.Tn.size_in_bytes in
let memcpy ~dst_ptr =
if same_device && Cu.Deviceptr.equal dst_ptr src_ptr then ()
else if same_device then
Expand All @@ -217,7 +217,6 @@ let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
dst.stream.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
| Copy, _ ->
set_ctx @@ ctx_of dst;
let size_in_bytes = Tn.size_in_bytes tn in
opt_alloc_merge_buffer ~size_in_bytes dev.dev dst.stream;
let buffer = Option.value_exn ~here:[%here] !(dst.stream.merge_buffer) in
memcpy ~dst_ptr:buffer.ptr
Expand Down
2 changes: 1 addition & 1 deletion arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ let zero_out ctx block node =
[
Lazy.force node.ptr;
RValue.zero ctx c_int;
RValue.int ctx c_index @@ Tn.size_in_bytes node.tn;
RValue.int ctx c_index @@ Lazy.force node.tn.size_in_bytes;
]

(** [get_c_ptr ctx num_typ ptr] returns [RValue.ptr ctx (Type.pointer num_typ) ptr] from the
    [Gccjit] bindings, i.e. an rvalue whose type is "pointer to [num_typ]".
    NOTE(review): presumably [ptr] is a raw host address being embedded as a constant -- confirm
    against the [Gccjit.RValue.ptr] documentation. *)
let get_c_ptr ctx num_typ ptr =
  let open Gccjit in
  RValue.ptr ctx (Type.pointer num_typ) ptr
Expand Down
16 changes: 8 additions & 8 deletions arrayjit/lib/tnode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,14 @@ type memory_mode =
optional [array] of {!t}). *)
[@@deriving sexp, compare, equal]

type delayed_prec =
| Not_specified
| Default_spec of (Ops.prec Lazy.t[@sexp.opaque])
| Specified of Ops.prec
type delayed_prec = Not_specified | Default_spec of Ops.prec Lazy.t | Specified of Ops.prec
[@@deriving sexp, equal]

type t = {
array : (Nd.t option Lazy.t[@sexp.opaque]);
prec : (Ops.prec Lazy.t[@sexp.opaque]);
dims : (int array Lazy.t[@sexp.opaque]);
array : Nd.t option Lazy.t;
prec : Ops.prec Lazy.t;
dims : int array Lazy.t;
size_in_bytes : int Lazy.t;
id : int;
label : string list;
(** Display information. It is better if the last element of the list is the most narrow or
Expand All @@ -90,7 +88,6 @@ let num_elems tn =
let dims = Lazy.force tn.dims in
if Array.is_empty dims then 0 else Array.reduce_exn dims ~f:( * )

let size_in_bytes tn = num_elems tn * Ops.prec_in_bytes (Lazy.force tn.prec)
(** [id tn] is the string ["n"] followed by the decimal rendering of the integer field [tn.id]. *)
let id tn = "n" ^ Int.to_string tn.id
(** [label tn] joins the components of the [tn.label] string list with ["_"] as separator. *)
let label tn = String.concat tn.label ~sep:"_"
(** [is_alphanum_ s] holds when every character of [s] is either an underscore or satisfies
    [Char.is_alphanum]. *)
let is_alphanum_ =
  let is_ident_char ch = Char.equal ch '_' || Char.is_alphanum ch in
  String.for_all ~f:is_ident_char
Expand Down Expand Up @@ -512,6 +509,7 @@ let create ?default_prec ~id ~label ~dims init_op =
| Specified prec | Default_spec (lazy prec) -> prec
| Not_specified ->
raise @@ Utils.User_error "Tnode.update_prec: precision is not specified yet")
and size_in_bytes = lazy (num_elems tn * Ops.prec_in_bytes (Lazy.force tn.prec))
and tn =
let delayed_prec_unsafe =
match default_prec with None -> Not_specified | Some prec -> Default_spec prec
Expand All @@ -521,6 +519,7 @@ let create ?default_prec ~id ~label ~dims init_op =
delayed_prec_unsafe;
prec;
dims;
size_in_bytes;
id;
label;
memory_mode = None;
Expand All @@ -541,6 +540,7 @@ let find =
prec = lazy Ops.single;
delayed_prec_unsafe = Specified Ops.single;
dims = lazy [||];
size_in_bytes = lazy 0;
id = -1;
label = [];
memory_mode = None;
Expand Down

0 comments on commit 77b3395

Please sign in to comment.