Bump ocamlformat version, reformat
lukstafi committed Dec 8, 2024
1 parent e85a09d commit 93b427d
Showing 22 changed files with 171 additions and 104 deletions.
2 changes: 1 addition & 1 deletion .ocamlformat
@@ -1,5 +1,5 @@
profile = default
version = 0.26.2
version = 0.27.0
margin = 100
parse-docstrings = true
wrap-comments = true
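A recurring pattern in the hunks below: ocamlformat 0.27.0 appears to move a doc comment's closing delimiter onto its own line when the comment would otherwise overflow the 100-column margin set above. A self-contained before/after sketch of that shape, modeled on the backend_intf.ml hunk further down (types simplified so the fragment compiles on its own):

    (* As formatted by ocamlformat 0.26.2: the closing delimiter stays on the comment's last line. *)
    module type Debug_old = sig
      type stream

      val get_debug_info : stream -> string
      (** Per-stream debug information; backend-specific and might evolve independently on the backends *)
    end

    (* As formatted by ocamlformat 0.27.0: the same comment no longer fits within the margin,
       so the closing delimiter moves to its own line. *)
    module type Debug_new = sig
      type stream

      val get_debug_info : stream -> string
      (** Per-stream debug information; backend-specific and might evolve independently on the backends
      *)
    end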
10 changes: 6 additions & 4 deletions arrayjit/lib/backend_impl.ml
@@ -84,9 +84,10 @@ end

module Device
(Device_types : Device_types)
(Alloc_buffer : Alloc_buffer
with type buffer_ptr := Device_types.buffer_ptr
and type stream := Device_types.stream) =
(Alloc_buffer :
Alloc_buffer
with type buffer_ptr := Device_types.buffer_ptr
and type stream := Device_types.stream) =
struct
include Device_types
include Alloc_buffer
@@ -138,7 +139,8 @@ module type Backend_impl_common = sig
[use_host_memory] can only be true on unified memory devices, like CPU and Apple Metal. *)
end

(** An interface to adding schedulers for stream-agnostic (typically CPU) backend implementations. *)
(** An interface to adding schedulers for stream-agnostic (typically CPU) backend implementations.
*)
module type For_add_scheduler = sig
include Backend_any_common

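The functor argument above is only re-wrapped by the new ocamlformat; its use of destructive substitution ([with type ... := ...]) is unchanged. For context, a minimal self-contained sketch of why that pattern is used (the module and value names here are hypothetical, not from the repository):

    module type Types = sig
      type buffer_ptr
      type stream
    end

    module type Alloc = sig
      type buffer_ptr
      type stream

      val alloc : stream -> size_in_bytes:int -> buffer_ptr
    end

    (* [:=] removes [buffer_ptr] and [stream] from the argument's signature and substitutes
       [T]'s types for them, so the two [include]s below do not clash. *)
    module Make
        (T : Types)
        (A : Alloc with type buffer_ptr := T.buffer_ptr and type stream := T.stream) =
    struct
      include T
      include A
    end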
6 changes: 4 additions & 2 deletions arrayjit/lib/backend_intf.ml
@@ -269,7 +269,8 @@ module type Backend_device_common = sig
(** Global debug information; backend-specific and might evolve independently on the backends. *)

val get_debug_info : stream -> Sexp.t
(** Per-stream debug information; backend-specific and might evolve independently on the backends *)
(** Per-stream debug information; backend-specific and might evolve independently on the backends
*)

val await : stream -> unit
(** Blocks till the stream becomes idle, i.e. synchronizes the stream. *)
@@ -329,7 +330,8 @@ module type Backend = sig
include Backend_device_common with type buffer_ptr := buffer_ptr

val link : context -> code -> context routine
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)
(** Returns the routine for the code's procedure, in a new context derived from the given context.
*)

val link_batch : context -> code_batch -> context * context routine option array
(** Returns the routines for the procedures included in the code batch. The returned context is
4 changes: 2 additions & 2 deletions arrayjit/lib/backends.ml
@@ -356,8 +356,8 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
in
{ from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level.merge_node }

let%debug3_sexp compile_batch ?shared ?names ?occupancy bindings (comps : Assignments.comp array) :
code_batch =
let%debug3_sexp compile_batch ?shared ?names ?occupancy bindings (comps : Assignments.comp array)
: code_batch =
let names, lowereds =
lower_batch_assignments ?names ?occupancy bindings
@@ Array.map comps ~f:(fun c -> c.Assignments.asgns)
10 changes: 5 additions & 5 deletions arrayjit/lib/cc_backend.ml
@@ -150,11 +150,11 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
[%log_level
0;
let rec link :
'a 'b 'idcs.
'idcs Indexing.bindings ->
param_source list ->
('a -> 'b) Ctypes.fn ->
('a -> 'b, 'idcs, 'p1, 'p2) Indexing.variadic =
'a 'b 'idcs.
'idcs Indexing.bindings ->
param_source list ->
('a -> 'b) Ctypes.fn ->
('a -> 'b, 'idcs, 'p1, 'p2) Indexing.variadic =
fun (type a b idcs) (binds : idcs Indexing.bindings) params (cs : (a -> b) Ctypes.fn) ->
match (binds, params) with
| Empty, [] -> Indexing.Result (Foreign.foreign ~from:code.result.lib name cs)
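The [link] helper above only has its type annotation re-indented. For context, a minimal sketch (toy GADT, not from the repository) of the construct that annotation uses, an explicitly polymorphic recursive function; the [type a.] form below is the usual shorthand for the expanded ['a 'b ... .] plus [fun (type a b ...)] shape that [link] spells out:

    type _ value = Int : int value | Pair : 'a value * 'b value -> ('a * 'b) value

    (* The explicit polymorphism lets [zero] recurse at different types, and the
       locally abstract type lets the GADT match refine the result type. *)
    let rec zero : type a. a value -> a = function
      | Int -> 0
      | Pair (l, r) -> (zero l, zero r)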
2 changes: 1 addition & 1 deletion arrayjit/lib/cuda_backend.cudajit.ml
@@ -193,7 +193,7 @@ let to_host ~src_ptr ~src hosted =

let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
(* Stdio.printf "run: device_to_device %s dst backend:0:%d src backend:0:%d\n" (Tn.debug_name tn)
dst.stream.stream_id src.stream.stream_id; *)
dst.stream.stream_id src.stream.stream_id; *)
let dev = dst.stream.device in
let same_device = dev.ordinal = src.stream.device.ordinal in
let size_in_bytes = Tn.size_in_bytes tn in
10 changes: 5 additions & 5 deletions arrayjit/lib/gcc_backend.gccjit.ml
@@ -654,11 +654,11 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
[%log_level
0;
let rec link :
'a 'b 'idcs.
'idcs Indexing.bindings ->
param_source list ->
('a -> 'b) Ctypes.fn ->
('a -> 'b, 'idcs, 'p1, 'p2) Indexing.variadic =
'a 'b 'idcs.
'idcs Indexing.bindings ->
param_source list ->
('a -> 'b) Ctypes.fn ->
('a -> 'b, 'idcs, 'p1, 'p2) Indexing.variadic =
fun (type a b idcs) (binds : idcs Indexing.bindings) params (cs : (a -> b) Ctypes.fn) ->
match (binds, params) with
| Empty, [] -> Indexing.Result (Gccjit.Result.code code.result name cs)
3 changes: 2 additions & 1 deletion arrayjit/lib/indexing.ml
@@ -121,7 +121,8 @@ type projections = {
(** The product space dimensions that an operation should parallelize (map-reduce) over. *)
lhs_dims : int array; (** The dimensions of the LHS array. *)
rhs_dims : int array array;
(** The dimensions of the RHS arrays, needed for deriving projections from other projections. *)
(** The dimensions of the RHS arrays, needed for deriving projections from other projections.
*)
product_iterators : symbol array;
(** The product space iterators (concatenation of the relevant batch, output, input axes) for
iterating over the [product_space] axes, where same axes are at same array indices. *)
8 changes: 4 additions & 4 deletions arrayjit/lib/low_level.ml
@@ -291,8 +291,8 @@ let%diagn_sexp check_and_store_virtual traced static_indices top_llc =
if not @@ Set.mem env_dom s then
[%log2
"INFO: Inlining candidate has an escaping variable",
(_idx : Indexing.axis_index),
(top_llc : t)];
(_idx : Indexing.axis_index),
(top_llc : t)];
raise @@ Non_virtual 7
| _ -> ());
loop_float ~env_dom llv
@@ -311,8 +311,8 @@
if not @@ Set.mem env_dom s then (
[%log2
"Inlining candidate has an escaping variable",
(s : Indexing.symbol),
(top_llc : t)];
(s : Indexing.symbol),
(top_llc : t)];
raise @@ Non_virtual 9)
| _ -> ())
| Local_scope { body; _ } -> loop_proc ~env_dom body
25 changes: 13 additions & 12 deletions arrayjit/lib/ndarray.ml
@@ -426,21 +426,22 @@ let log_debug_info ~from_log_level:_level _nd =
"Ndarray " ^ Sexp.to_string_hum (sexp_of_t _nd);
[%log
"value-at-0:",
(get_as_float _nd (Array.map (dims _nd) ~f:(fun _ -> 0)) : float),
"has nan:",
(fold_as_float _nd ~init:false ~f:(fun has_nan _ v -> has_nan || Float.is_nan v) : bool),
"has +inf:",
(fold_as_float _nd ~init:false ~f:(fun has_inf _ v -> has_inf || Float.(v = infinity))
: bool),
"has -inf:",
(fold_as_float _nd ~init:false ~f:(fun has_neg_inf _ v ->
has_neg_inf || Float.(v = neg_infinity))
: bool)]]]]
(get_as_float _nd (Array.map (dims _nd) ~f:(fun _ -> 0)) : float),
"has nan:",
(fold_as_float _nd ~init:false ~f:(fun has_nan _ v -> has_nan || Float.is_nan v) : bool),
"has +inf:",
(fold_as_float _nd ~init:false ~f:(fun has_inf _ v -> has_inf || Float.(v = infinity))
: bool),
"has -inf:",
(fold_as_float _nd ~init:false ~f:(fun has_neg_inf _ v ->
has_neg_inf || Float.(v = neg_infinity))
: bool)]]]]

let concise_float ~prec v =
Printf.sprintf "%.*e" prec v
|> (* The C99 standard requires at least two digits for the exponent, but the leading zero is a
waste of space. *)
|>
(* The C99 standard requires at least two digits for the exponent, but the leading zero is a waste
of space. *)
String.substr_replace_first ~pattern:"e+0" ~with_:"e+"
|> String.substr_replace_first ~pattern:"e-0" ~with_:"e-"

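Above, only the wrapping of the inline comment around [|>] changes; the behaviour of [concise_float] stays the same. A self-contained sketch of what it computes (Base and Stdio assumed as the dependencies, expected output noted in comments):

    open Base

    let concise_float ~prec v =
      Printf.sprintf "%.*e" prec v
      |> String.substr_replace_first ~pattern:"e+0" ~with_:"e+"
      |> String.substr_replace_first ~pattern:"e-0" ~with_:"e-"

    let () =
      (* Prints 3.14e+7 rather than the C99-style 3.14e+07. *)
      Stdio.printf "%s\n" (concise_float ~prec:2 31415926.);
      (* Prints 3.14e-3. *)
      Stdio.printf "%s\n" (concise_float ~prec:2 0.00314)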
4 changes: 2 additions & 2 deletions arrayjit/lib/ppx_helper.ml
@@ -99,8 +99,8 @@ let ndarray_constant expr =
let result = loop_values 0 [] expr in
let values = { expr with pexp_desc = Pexp_array (List.rev result) } in
let batch_dims, output_dims, input_dims =
Array.fold dims_spec ~init:([], [], []) ~f:(fun (batch_dims, output_dims, input_dims) ->
function
Array.fold dims_spec ~init:([], [], [])
~f:(fun (batch_dims, output_dims, input_dims) -> function
| `Input_dims dim -> (batch_dims, output_dims, eint ~loc dim :: input_dims)
| `Output_dims dim -> (batch_dims, eint ~loc dim :: output_dims, input_dims)
| `Batch_dims dim -> (eint ~loc dim :: batch_dims, output_dims, input_dims))
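The fold above is only re-broken across lines by the new formatter. Its [~f:(fun acc -> function ...)] shape, where the callback takes the accumulator and then pattern-matches on the element, can be sketched on toy data (Base assumed; the names are made up):

    open Base

    (* Partition an array into even and odd elements while folding. *)
    let evens, odds =
      Array.fold [| 1; 2; 3; 4; 5 |] ~init:([], []) ~f:(fun (evens, odds) -> function
        | n when n % 2 = 0 -> (n :: evens, odds)
        | n -> (evens, n :: odds))
    (* evens = [4; 2], odds = [5; 3; 1] *)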
6 changes: 4 additions & 2 deletions bin/einsum_trivia.ml
@@ -54,11 +54,13 @@ let () =
Utils.settings.output_debug_files_in_build_directory <- true;
Utils.settings.debug_log_from_routines <- true;
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
let backend = (module Backend : Backend
let backend =
(module Backend : Backend
with type buffer_ptr = Backend.buffer_ptr
and type dev = Backend.dev
and type runner = Backend.runner
and type event = Backend.event) in
and type event = Backend.event)
in
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
let ctx = Backend.make_context stream in
Rand.init 0;
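The [backend] binding above (and the identical ones in hello_world.ml below) packs a first-class module with type-equality constraints, now broken across lines by the formatter. A minimal self-contained sketch of that idiom with hypothetical module names:

    module type Counter = sig
      type state

      val init : state
      val step : state -> state
    end

    module Int_counter = struct
      type state = int

      let init = 0
      let step n = n + 1
    end

    (* Pack the module as a value, exposing the identity of [state] in the package type ... *)
    let packed = (module Int_counter : Counter with type state = int)

    (* ... and unpack it locally wherever a module is needed again. *)
    let next () =
      let module C = (val packed : Counter with type state = int) in
      C.step C.init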
18 changes: 12 additions & 6 deletions bin/hello_world.ml
@@ -72,11 +72,13 @@ let hello3 () =

let hello4 () =
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
let backend = (module Backend : Backend
let backend =
(module Backend : Backend
with type buffer_ptr = Backend.buffer_ptr
and type dev = Backend.dev
and type runner = Backend.runner
and type event = Backend.event) in
and type event = Backend.event)
in
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
let ctx = Backend.make_context stream in
Rand.init 0;
@@ -105,11 +107,13 @@ let hello5 () =
Utils.settings.output_debug_files_in_build_directory <- true;
Utils.settings.debug_log_from_routines <- true;
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
let backend = (module Backend : Backend
let backend =
(module Backend : Backend
with type buffer_ptr = Backend.buffer_ptr
and type dev = Backend.dev
and type runner = Backend.runner
and type event = Backend.event) in
and type event = Backend.event)
in
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
let ctx = Backend.make_context stream in
Rand.init 0;
@@ -124,11 +128,13 @@ let hello6 () =
Utils.settings.output_debug_files_in_build_directory <- true;
Utils.settings.debug_log_from_routines <- true;
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
let backend = (module Backend : Backend
let backend =
(module Backend : Backend
with type buffer_ptr = Backend.buffer_ptr
and type dev = Backend.dev
and type runner = Backend.runner
and type event = Backend.event) in
and type event = Backend.event)
in
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
let ctx = Backend.make_context stream in
Rand.init 0;
2 changes: 1 addition & 1 deletion dune-project
@@ -58,7 +58,7 @@
(>= 0.23))
; opam 2.2.0 has with-dev-setup. Is it supported, what's the syntax?
; (ocamlformat
; (>= 0.26.2)
; (>= 0.27.0)
; :with-dev-setup)
printbox
printbox-text
3 changes: 2 additions & 1 deletion lib/row.ml
@@ -108,7 +108,8 @@ type dim_constraint = Unconstrained_dim | At_least_dim of int
type row_constraint =
| Unconstrained
| Total_elems of { nominator : int; divided_by : Set.M(Dim_var).t }
(** The row or remainder of a row, inclusive of the further row spec, has this many elements. *)
(** The row or remainder of a row, inclusive of the further row spec, has this many elements.
*)
[@@deriving equal, hash, compare, sexp, variants]

(** An entry implements inequalities [cur >= v >= subr] and/or an equality [v = solved]. [cur] and
6 changes: 4 additions & 2 deletions lib/row.mli
@@ -29,7 +29,8 @@ type row_var [@@deriving sexp, compare, equal, hash]

val get_row_var : unit -> row_var

(** A bcast specifies how axes of a single kind in a shape (i.e. the row) can adapt to other shapes. *)
(** A bcast specifies how axes of a single kind in a shape (i.e. the row) can adapt to other shapes.
*)
type bcast =
| Row_var of { v : row_var; beg_dims : dim list }
(** The row can be inferred to have more axes. *)
@@ -58,7 +59,8 @@ type dim_constraint = Unconstrained_dim | At_least_dim of int
type row_constraint =
| Unconstrained
| Total_elems of { nominator : int; divided_by : dim_var_set }
(** The row or remainder of a row, inclusive of the further row spec, has this many elements. *)
(** The row or remainder of a row, inclusive of the further row spec, has this many elements.
*)
[@@deriving equal, hash, compare, sexp, variants]

(** An entry implements inequalities [cur >= v >= subr] and/or an equality [v = solved]. [cur] and
3 changes: 2 additions & 1 deletion lib/shape.ml
@@ -18,7 +18,8 @@ let _get_local_debug_runtime = Arrayjit.Utils._get_local_debug_runtime
Note the following inconsistency due to differing conventions in function notation and matrix
notation: for label specifications and einsum notation, we write "batch|inputs->outputs", but
when we convert a shape to an [Ndarray] index we do it in the order [[batch; outputs; inputs]]. *)
when we convert a shape to an [Ndarray] index we do it in the order [[batch; outputs; inputs]].
*)
module AxisKey = struct
module T = struct
type kind = [ `Batch | `Input | `Output ] [@@deriving equal, compare, sexp, hash]
3 changes: 1 addition & 2 deletions lib/tensor.ml
@@ -315,8 +315,7 @@ let number ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) c =
in
Tn.update_memory_mode t.value Effectively_constant 24;
Arrayjit.Ops.(
if Tn.exceeds_fp16_cutoff t.value c then
Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
if Tn.exceeds_fp16_cutoff t.value c then Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
t

let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?output_dims