Automated from_host transfers
lukstafi committed Dec 30, 2024
1 parent 1a33588 commit 6d41b75
Showing 11 changed files with 12 additions and 48 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
@@ -3,6 +3,7 @@
## Added

- Automatic transfers to host from the context that most recently updated a node.
- Automatic transfers of a routine's inputs from host to the routine's context if the host array modification was not yet transferred.

## Fixed

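What this means in user code: the explicit host-to-device pushes that the rest of this commit deletes are no longer needed. The fragment below contrasts the two styles; it is assembled from calls that appear elsewhere in this commit (`Train.to_routine`, `Train.all_host_to_device`, `Train.run`) and assumes `Backend`, `ctx`, `update`, `scalar_loss`, and `learning_rate` are set up as in the demos, so it is a sketch rather than standalone code.

```ocaml
(* Before this commit: host arrays had to be pushed to the device explicitly
   before running a linked routine. *)
let routine = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
Train.all_host_to_device (module Backend) routine.context scalar_loss;
Train.all_host_to_device (module Backend) routine.context learning_rate;
Train.run routine

(* After this commit: scheduling the routine transfers any hosted input whose
   host array was modified since its last transfer, so the explicit calls go away. *)
let routine = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
Train.run routine
```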
7 changes: 6 additions & 1 deletion arrayjit/lib/backends.ml
@@ -70,7 +70,9 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
Tn.prepare_read
~is_done:(fun () -> Backend.is_done e)
~sync:(fun () -> Backend.sync e)
~transfer:(fun () -> assert (to_host ctx tn); Backend.await s)
~transfer:(fun () ->
assert (to_host ctx tn);
Backend.await s)
tn);
(* To be on the safe side, record events for potentially cross-stream nodes. *)
match tn with
@@ -92,6 +94,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
(* Stdio.printf "copying: %s from_host\n" (Tn.debug_name tn); *)
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
update_writer_event ~from:`Host ctx @@ Node tn;
tn.host_modified <- false;
true
| _ -> false

@@ -140,6 +143,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
let s = r.context.stream in
let hosted_inputs = Set.filter r.inputs ~f:(fun tn -> Tn.is_hosted_force tn 47) in
let pre () =
assert (Domain.is_main_domain ());
Set.iter hosted_inputs ~f:(fun tn -> if tn.host_modified then assert (from_host r.context tn));
Set.iter r.inputs ~f:(fun tn ->
if Tn.potentially_cross_stream tn then
Option.iter (Hashtbl.find s.device.shared_writer_streams tn) ~f:(fun data ->
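The mechanism in the hunk above: the `pre` hook filters the routine's hosted inputs and calls `from_host` for every one whose `host_modified` flag is set, while `from_host` clears the flag after enqueueing the copy. Below is a standalone toy model of that dirty-flag scheme in plain OCaml; the type and function names are invented for the illustration and are not the library's API.

```ocaml
(* Toy model of the dirty-flag scheme: not OCANNL's API, just an illustration. *)
type toy_node = {
  name : string;
  mutable host : float array;    (* host-side array *)
  mutable device : float array;  (* stand-in for the device-side copy *)
  mutable host_modified : bool;  (* set on host writes, cleared on transfer *)
}

(* A fresh node starts with [host_modified = true], mirroring [Tnode.create],
   so its initial host contents are transferred on first use. *)
let create name n =
  { name; host = Array.make n 0.; device = Array.make n 0.; host_modified = true }

(* Analogue of a host-side write such as [Tn.set_values]: it re-marks the node. *)
let set_values node values =
  node.host <- Array.copy values;
  node.host_modified <- true

(* Analogue of [from_host]: copy host -> device and clear the flag. *)
let from_host node =
  node.device <- Array.copy node.host;
  node.host_modified <- false;
  Printf.printf "transferred %s from host\n" node.name

(* Analogue of the routine's [pre] hook: before running, transfer every hosted
   input whose host array changed since its last transfer. *)
let run inputs compute =
  List.iter (fun n -> if n.host_modified then from_host n) inputs;
  compute ()

let () =
  let x = create "x" 2 in
  run [ x ] (fun () -> ());    (* first run: transfers x's initial values *)
  run [ x ] (fun () -> ());    (* host unchanged: no transfer *)
  set_values x [| 1.0; 2.0 |]; (* host write marks x as modified *)
  run [ x ] (fun () -> ())     (* host modified: transferred again *)
```

The real implementation additionally tracks which context performed a write and records events for potentially cross-stream nodes (the `update_writer_event` and `shared_writer_streams` logic above); the toy drops all of that.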
6 changes: 5 additions & 1 deletion arrayjit/lib/tnode.ml
@@ -84,6 +84,7 @@ type t = {
mutable code_name : string option;
mutable prepare_read : prepare option;
mutable prepare_write : prepare option;
mutable host_modified : bool;
}
[@@deriving sexp_of]

@@ -553,6 +554,7 @@ let create ?default_prec ~id ~label ~dims init_op =
code_name = None;
prepare_read = None;
prepare_write = None;
host_modified = true;
}
in
(* Note: if tensor nodes get non-trivial finalizers, remember to either add an is_finalized flag
@@ -576,6 +578,7 @@ let find =
code_name = None;
prepare_read = None;
prepare_write = None;
host_modified = false;
}
in
fun ~id -> Registry.find_opt registry { mock with id }
@@ -592,7 +595,8 @@ let do_read tn =

let do_write tn =
Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
tn.prepare_write <- None
tn.prepare_write <- None;
tn.host_modified <- true

let points_1d ?from_axis ~xdim tn =
do_read tn;
1 change: 0 additions & 1 deletion bin/compilation_speed.ml
@@ -39,7 +39,6 @@ let benchmark_overhead backend () =
Train.to_routine (module Backend) init_assign_x.context IDX.empty update_f.fwd_bprop
in
Tensor.print_tree ~with_grad:true ~with_backend_info:true ~depth:9 f;
Tensor.iter_embedded f ~f:(fun a -> ignore (Backend.from_host f_routine.context a : bool));

let xs = Array.init n_data ~f:Float.(fun i -> of_int i - (of_int n_data /. 2.)) in
let open Operation.At in
2 changes: 0 additions & 2 deletions bin/hello_world.ml
@@ -56,8 +56,6 @@ let hello3 () =
let y = TDSL.O.(( + ) ~label:[ "y" ] (hey * zero_to_twenty) zero_to_twenty) in
Train.set_hosted hey.value;
let routine = Train.to_routine (module Backend) ctx IDX.empty @@ Train.forward y in
assert (Backend.from_host routine.context hey.value);
assert (Backend.from_host routine.context zero_to_twenty.value);
Tensor.print ~with_code:true ~with_grad:false `Inline zero_to_twenty;
Tensor.print ~with_code:true ~with_grad:false `Default zero_to_twenty;
Tensor.print_tree ~with_grad:false ~depth:9 zero_to_twenty;
3 changes: 0 additions & 3 deletions bin/micrograd_demo.ml
@@ -85,8 +85,6 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
let routine =
Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update.fwd_bprop; sgd ])
in
Train.all_host_to_device (module Backend) routine.context scalar_loss;
Train.all_host_to_device (module Backend) routine.context learning_rate;
(* Stdio.print_endline "\n******** scalar_loss **********"; Tensor.print_tree ~with_id:true
~with_grad:false ~depth:9 scalar_loss; Stdio.print_endline "\n******** learning_rate
**********"; Tensor.print_tree ~with_id:true ~with_grad:false ~depth:9 learning_rate;
@@ -136,7 +134,6 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
Tn.set_values point.value [| x; y |];
(* For the gccjit backend, point is only on host, not on device. For cuda, this will be
needed. *)
assert (Backend.from_host result_routine.context point.value);
Train.run result_routine;
Float.(mlp_result.@[0] >= 0.)
in
3 changes: 0 additions & 3 deletions bin/moons_demo.ml
@@ -78,8 +78,6 @@ let demo () =
PrintBox_text.output Stdio.stdout plot_moons;
Stdio.print_endline "\n";

Train.all_host_to_device (module Backend) routine.context scalar_loss;
Train.all_host_to_device (module Backend) routine.context learning_rate;
let open Operation.At in
let step_ref = IDX.find_exn routine.bindings step_n in
let batch_ref = IDX.find_exn routine.bindings batch_n in
@@ -112,7 +110,6 @@ let demo () =
let callback (x, y) =
Tn.set_values point.value [| x; y |];
Utils.capture_stdout_logs @@ fun () ->
assert (Backend.from_host result_routine.context point.value);
Train.run result_routine;
Float.(mlp_result.@[0] >= 0.)
in
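This is the pattern the automatic input transfer is aimed at: the inference callbacks in bin/micrograd_demo.ml and bin/moons_demo.ml above overwrite the point's host values and then just run the routine. A fragment mirroring the updated callback (identifiers as in the demos; not standalone code, and as the tnode.ml hunk suggests, host writes such as `Tn.set_values` go through `do_write`):

```ocaml
let callback (x, y) =
  (* Writing host values marks the node as host-modified ([do_write] above). *)
  Tn.set_values point.value [| x; y |];
  (* The routine's [pre] hook sees the flag and re-transfers the point, so the
     explicit [Backend.from_host] assert could be deleted. *)
  Train.run result_routine;
  Float.(mlp_result.@[0] >= 0.)
```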
7 changes: 0 additions & 7 deletions bin/zero2hero_1of7.ml
@@ -161,7 +161,6 @@ let () =
let routine =
Train.to_routine (module Backend) (Backend.make_context stream) IDX.empty update.fwd_bprop
in
Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.from_host routine.context a : bool));
Train.run routine;
(* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
Backend.await stream; *)
@@ -176,18 +175,12 @@ let () =
@@ Train.sgd_update ~learning_rate update
in
(* learning_rate is virtual so this will not print anything. *)
Tensor.iter_embedded learning_rate ~f:(fun a ->
ignore (Backend.from_host routine.context a : bool));
Stdio.print_endline
{|
Due to how the gccjit backend works, since the parameters were constant in the grad_update
computation, they did not exist on the device before. Now they do. This would not be needed
on the cuda backend.|};
List.iter [ a.value; b.value; c.value; f.value ] ~f:(fun a ->
assert (Backend.from_host routine.context a));
Train.run routine;
(* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
Backend.await stream; *)
Stdio.print_endline
{|
Now we updated the params, but after the forward and backward passes:
14 changes: 0 additions & 14 deletions lib/train.ml
@@ -269,16 +269,6 @@ let every_non_literal_on_host =
Tensor.iter_embedded ~f:(fun a ->
if Tn.mode_is_unspecified a && not (Tn.known_constant a) then set_hosted a)

(* Note: this will get nicer with modular explicits. *)
let%debug2_sexp all_host_to_device (type buffer_ptr dev runner event)
(module Backend : Backend
with type buffer_ptr = buffer_ptr
and type dev = dev
and type runner = runner
and type event = event) (context : Backend.context) =
let f tn = ignore (Backend.from_host context tn : bool) in
Tensor.iter_embedded ~f

module Lazy = Utils.Lazy

(** Performs one optimization step, potentially in parallel (if [grad_updates] are linked with
@@ -469,8 +459,6 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
let sgd_update = to_routine (module Backend) grad_updates.(0).context bindings sgd in
Tensor.log_debug_info ~from_log_level:2 inputs;
Tensor.log_debug_info ~from_log_level:2 outputs;
all_host_to_device (module Backend) sgd_update.context scalar_loss;
all_host_to_device (module Backend) sgd_update.context learning_rate;
let open Operation.At in
let epoch_loss = ref 0. in
let step_ref = IDX.find_exn sgd_update.bindings step_n in
@@ -531,7 +519,6 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
(* For the gccjit backend, infer is only on host, not on device. For cuda, this will be
needed. *)
Utils.capture_stdout_logs @@ fun () ->
assert (Backend.from_host routine.context infer.value);
run routine;
Tn.get_values model_result.value
in
@@ -558,7 +545,6 @@ let%track3_sexp forward_and_ctx ?(disable_rootness_check = false) (type buffer_p
and type event = event) ctx ?(bindings = IDX.empty) t =
let routine = Backend.(link ctx @@ compile bindings @@ forward ~disable_rootness_check t) in
if not disable_rootness_check then Tensor.remove_bprop_root t;
Tensor.iter_embedded t ~f:(fun a -> ignore (Backend.from_host routine.context a : bool));
Task.run routine.schedule;
routine.context

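With routines pushing their own inputs, the bulk `Train.all_host_to_device` helper and the `from_host` loops in `example_train_loop` and `forward_and_ctx` become redundant: running a linked routine (`Train.run`, or `Task.run routine.schedule` as in `forward_and_ctx`) already transfers any host-modified hosted input. The per-node primitive the helper wrapped is still available if a caller ever wants to force a push; in the fragment below, `Backend`, `context`, and `tn` are assumed to be in scope, and the returned bool indicates whether a copy was actually scheduled.

```ocaml
(* Manual, per-node push; the removed bulk helper just iterated this over all
   embedded tensor nodes. *)
ignore (Backend.from_host context tn : bool)
```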
11 changes: 0 additions & 11 deletions test/micrograd_demo.ml
@@ -29,7 +29,6 @@ let%expect_test "Micrograd README basic example" =
List.iter ~f:(Option.iter ~f:(fun diff -> Train.set_hosted diff.Tensor.grad)) [ a.diff; b.diff ];
let update = Train.grad_update g in
let step = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
Tensor.iter_embedded g ~f:(fun a -> ignore (Backend.from_host step.context a : bool));
Train.run step;
Tensor.print ~with_code:false ~with_grad:false `Default g;
[%expect
@@ -89,13 +88,6 @@ let%expect_test "Micrograd half-moons example" =
(* Note: for an as-yet unknown reason, this test can lead to different results on different versions
of dependencies. *)
let module Backend = (val Arrayjit.Backends.fresh_backend ~backend_name:"cc" ()) in
let backend =
(module Backend : Backend
with type buffer_ptr = Backend.buffer_ptr
and type dev = Backend.dev
and type runner = Backend.runner
and type event = Backend.event)
in
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
let ctx = Backend.make_context stream in
let open Operation.At in
@@ -148,8 +140,6 @@ let%expect_test "Micrograd half-moons example" =
let sgd_routine =
Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update.fwd_bprop; sgd ])
in
Train.all_host_to_device backend sgd_routine.context scalar_loss;
Train.all_host_to_device backend sgd_routine.context learning_rate;
let step_ref = IDX.find_exn sgd_routine.bindings step_n in
step_ref := 0;
for _epoch = 1 to epochs do
@@ -180,7 +170,6 @@
Tn.set_values point.value [| x; y |];
(* For the gccjit backend, point is only on host, not on device. For cuda, this will be
needed. *)
assert (Backend.from_host result_routine.context point.value);
Train.run result_routine;
Float.(mlp_result.@[0] >= 0.)
in
5 changes: 0 additions & 5 deletions test/zero2hero_1of7.ml
@@ -53,7 +53,6 @@ let%expect_test "Graph drawing recompile" =
Train.every_non_literal_on_host f;
let f_upd = Train.grad_update f in
let f_bprop = Train.to_routine (module Backend) ctx IDX.empty f_upd.fwd_bprop in
Tensor.iter_embedded f ~f:(fun a -> ignore (Backend.from_host f_bprop.context a : bool));
Train.run f_bprop;
Tensor.print_tree ~with_grad:true ~depth:9 f;
[%expect
@@ -279,7 +278,6 @@ let%expect_test "Simple gradients hosted" =
|}];
(* Do not update the params: all values and gradients will be at initial points, which are
specified in the tensor in the brackets. *)
Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.from_host grad_routine.context a : bool));
Train.run grad_routine;
Tensor.print_tree ~with_grad:true ~depth:9 l;
[%expect
@@ -410,7 +408,6 @@ let%expect_test "Simple gradients virtual" =
|}];
(* Do not update the params: all values and gradients will be at initial points, which are
specified in the tensor in the brackets. *)
Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.from_host grad_routine.context a : bool));
Train.run grad_routine;
Tensor.print_tree ~with_grad:true ~depth:9 l;
[%expect
@@ -497,7 +494,6 @@ let%expect_test "2D neuron hosted" =
Train.every_non_literal_on_host v;
let update = Train.grad_update v in
let routine = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
Tensor.iter_embedded v ~f:(fun a -> ignore (Backend.from_host routine.context a : bool));
Train.run routine;
Tensor.print_tree ~with_grad:true ~depth:9 v;
[%expect
@@ -525,7 +521,6 @@ let%expect_test "2D neuron virtual" =
let%op v = ("w" [ (-3, 1) ] * "x" [ 2; 0 ]) + "b" [ 6.7 ] in
let update = Train.grad_update v in
let routine = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
Tensor.iter_embedded v ~f:(fun a -> ignore (Backend.from_host routine.context a : bool));
Train.run routine;
Tensor.print_tree ~with_grad:true ~depth:9 v;
[%expect
