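(* moons_benchmark.ml: benchmarks OCANNL training on the classic half-moons
   binary-classification task, sweeping backends, numeric precisions, batch sizes,
   inlining cutoffs and stream counts. *)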
open Base
open Ocannl
module Nd = Arrayjit.Ndarray
module Ops = Arrayjit.Ops
module Tn = Arrayjit.Tnode
module IDX = Train.IDX
module TDSL = Operation.TDSL
module NTDSL = Operation.NTDSL
module CDSL = Train.CDSL
module Utils = Arrayjit.Utils
module Rand = Arrayjit.Rand.Lib
module Debug_runtime = Utils.Debug_runtime
let _get_local_debug_runtime = Arrayjit.Utils._get_local_debug_runtime
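(* Log level 9 is the verbose default; the second directive lets the OCANNL_LOG_LEVEL
   environment variable override it when the file is compiled. *)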
[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~backend_name
~value_prec ~grad_prec () =
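  (* Emit an initial entry into the debug trace so the run has a visible start marker. *)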
[%track_sexp
let _debug : string = "started" in
(fun (started : unit) -> started) ()];
(* ignore seed; *)
let bench_title =
[%string
"seed %{seed#Int}, inline %{inlining_cutoff#Int}, parallel %{num_streams#Int}, batch \
%{batch_size#Int}, backend %{backend_name}, val prec %{Ops.prec_string value_prec}, grad \
prec %{Ops.prec_string grad_prec}"]
in
Stdio.printf "\n*** %s ***\n%!" bench_title;
CDSL.virtualize_settings.enable_device_only <- on_device;
CDSL.virtualize_settings.max_visits <- inlining_cutoff;
Tensor.default_value_prec := value_prec;
Tensor.default_grad_prec := grad_prec;
Utils.settings.output_debug_files_in_build_directory <- true;
(* This will only log from routines if log-level is high enough. *)
Utils.settings.debug_log_from_routines <- true;
Rand.init (* seed *) 0;
let hid_dim_1 = 16 in
let hid_dim_2 = 8 in
let hid_dim_3 = 4 in
(* TINY for debugging: *)
(* let hid_dim = 2 in *)
(* let hid_dim = 4 in *)
let data_len = 3 * 5 * 1024 in
(* TINY for debugging: *)
(* let data_len = 3 * 4 in *)
(* let data_len = 3 * 8 in *)
let flat_len = data_len / 2 in
(* Note: [minibatch_size = batch_size / num_streams] is the actual per-device batch used. *)
(* let epochs = 400 in *)
let epochs = 100 in
(* let epochs = 50 in *)
(* TINY for debugging: *)
(* let epochs = 3 in *)
(* let epochs = 2 in *)
(* let epochs = 1 in *)
(* let init_lr = 0.1 in *)
let init_lr = 0.01 in
let noise () = Rand.float_range (-0.1) 0.1 in
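  (* Two interleaved half-circles ("moons"): for a random arc position v, emit one
     jittered point (cos v, sin v) on the first moon and one jittered point
     (1 - cos v, 0.5 - sin v) on the second, so even-indexed points carry class 1 and
     odd-indexed points class -1 below. *)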
let moons_flat =
Array.concat_map (Array.create ~len:flat_len ())
~f:
Float.(
fun () ->
let i = Rand.int flat_len in
let v = of_int i * pi / of_int flat_len in
let c = cos v and s = sin v in
[| c + noise (); s + noise (); 1.0 - c + noise (); 0.5 - s + noise () |])
in
let moons_flat ~b = TDSL.init_const ~l:"moons_flat" ~b ~o:[ 2 ] moons_flat in
let moons_classes = Array.init data_len ~f:(fun i -> if i % 2 = 0 then 1. else -1.) in
let moons_classes ~b = TDSL.init_const ~l:"moons_classes" ~b ~o:[ 1 ] moons_classes in
let init_time = Time_now.nanoseconds_since_unix_epoch () in
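  (* A three-hidden-layer MLP (16 -> 8 -> 4 units) with relu nonlinearities (written
     [?/]); the string literals "w1".."w4" and "b1".."b3" declare learnable parameters
     inline. *)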
let%op mlp x =
"w4"
* ?/("b3" hid_dim_3 + ("w3" * ?/("b2" hid_dim_2 + ("w2" * ?/("b1" hid_dim_1 + ("w1" * x))))))
in
(* TINY for debugging: *)
(* let%op mlp x = "w2" * ?/("b1" hid_dim + ("w1" * x)) in *)
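  (* Per-example hinge loss: relu (1 - y * f(x)), for labels y in {-1, 1}. *)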
let%op loss_fn ~output ~expectation = ?/(!..1 - (expectation *. output)) in
let start_time = ref None in
let weight_decay = 0.0002 in
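  (* Suggest to the synchronous schedulers how many streams to spread work across. *)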
Arrayjit.Schedulers.sync_suggested_num_streams := num_streams;
let module Backend = (val Arrayjit.Backends.fresh_backend ~backend_name ()) in
Stdlib.Format.printf "Initial backend global debug info: %a\n%!" Sexp.pp_hum
@@ Backend.get_global_debug_info ();
let per_batch_callback ~at_batch:_ ~at_step:_ ~learning_rate:_ ~batch_loss:_ ~epoch_loss:_ =
(* Stdio.printf "Batch=%d, step=%d, lr=%f, batch loss=%f, epoch loss=%f\n%!" at_batch at_step
learning_rate batch_loss epoch_loss; *)
if Option.is_none !start_time then start_time := Some (Time_now.nanoseconds_since_unix_epoch ())
in
(* Tn.print_accessible_headers (); *)
let per_epoch_callback ~at_step ~at_epoch ~learning_rate ~epoch_loss =
if at_epoch % 10 = 9 then
Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate
epoch_loss
in
Backend.initialize Train.BT.Most_parallel_streams;
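  (* [example_train_loop] runs the full training loop and returns the input/output data
     tensors, the model's result tensor, an inference closure, the (reversed) loss
     histories, the learning-rate history, and the device memory used. *)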
let {
Train.inputs;
outputs;
model_result;
infer_callback;
rev_batch_losses;
rev_epoch_losses;
learning_rates;
used_memory;
} =
Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_streams:num_streams ~data_len
~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay
~per_batch_callback ~per_epoch_callback ~per_epoch_debug_streams:false
(module Backend)
()
in
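  (* Recover the raw data from the tensors and split the points by their +/-1 label for
     the scatterplot. *)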
let points = Tn.points_2d ~xdim:0 ~ydim:1 inputs.value in
let classes = Tn.points_1d ~xdim:0 outputs.value in
let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
Stdio.print_endline "\n******** mlp_result **********";
Tensor.print_tree ~with_id:true ~with_grad:false ~depth:9 model_result;
Stdio.printf "\n********\n%!";
(* Arrayjit.Tnode.print_accessible_headers (); *)
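  (* Decision rule for the boundary map: a point is classified positive when the
     inferred output is >= 0. *)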
let callback (x, y) = Float.((infer_callback [| x; y |]).(0) >= 0.) in
let%track3_sexp plot_moons () =
(* [%log_level 0; *)
PrintBox_utils.plot
(* TINY for debugging: *)
(* ~small:true *)
~as_canvas:true
[
Scatterplot { points = points1; content = PrintBox.line "#" };
Scatterplot { points = points2; content = PrintBox.line "%" };
Boundary_map
{ content_false = PrintBox.line "."; content_true = PrintBox.line "*"; callback };
]
(* ] *)
in
Stdio.printf "\nHalf-moons scatterplot and decision boundary:\n%!";
PrintBox_text.output Stdio.stdout @@ plot_moons ();
Stdio.printf "\nBatch Log-loss:\n%!";
let plot_loss =
PrintBox_utils.plot ~x_label:"step" ~y_label:"batch log loss"
[
Line_plot
{
points =
Array.of_list_rev_map rev_batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x));
content = PrintBox.line "-";
};
]
in
PrintBox_text.output Stdio.stdout plot_loss;
Stdio.printf "\nEpoch Log-loss:\n%!";
let plot_loss =
PrintBox_utils.plot ~x_label:"step" ~y_label:"epoch log loss"
[
Line_plot
{
points = Array.of_list_rev_map rev_epoch_losses ~f:Float.log;
content = PrintBox.line "-";
};
]
in
PrintBox_text.output Stdio.stdout plot_loss;
Stdio.printf "\nLearning rate:\n%!";
let plot_lr =
PrintBox_utils.plot ~x_label:"step" ~y_label:"learning rate"
[ Line_plot { points = Array.of_list_rev learning_rates; content = PrintBox.line "-" } ]
in
PrintBox_text.output Stdio.stdout plot_lr;
let final_time = Time_now.nanoseconds_since_unix_epoch () in
(* TODO: include init time in benchmarks? *)
let init_time_in_sec =
    Int63.(to_float @@ (Option.value_exn ~here:[%here] !start_time - init_time)) /. 1_000_000_000.
in
let time_in_sec =
    Int63.(to_float @@ (final_time - Option.value_exn ~here:[%here] !start_time)) /. 1_000_000_000.
in
Stdio.printf "\nTime in sec: %f\n%!" time_in_sec;
let result =
PrintBox_utils.Benchmark
{
bench_title;
time_in_sec;
mem_in_bytes = used_memory;
result_label = "init time in sec, min loss, last loss";
result =
[%sexp_of: float * float * float]
( init_time_in_sec,
List.reduce_exn rev_epoch_losses ~f:Float.min,
List.hd_exn rev_epoch_losses );
}
in
Stdio.printf "\n\n%!";
(* Arrayjit.Tnode.print_accessible_headers (); *)
Stdlib.Format.printf "Final backend global debug info: %a\n%!" Sexp.pp_hum
@@ Backend.get_global_debug_info ();
result
let _suspended () =
ignore
@@ classify_moons ~seed:0 ~on_device:true ~inlining_cutoff:3 ~num_streams:8 ~batch_size:16
~backend_name:"gccjit" ~value_prec:CDSL.single ~grad_prec:CDSL.double ()
let _cuda_benchmarks =
List.concat_map [ 1; 3; 6; 12; 16; 20 (* 32; 64 *) ] ~f:(fun num_streams ->
List.concat_map
[
(* TINY for debugging: *)
(* 3 * 2 *)
3 * 5 * 16 (* ; 3 * 5 * 32; 3 * 5 * 64 *);
]
~f:(fun batch_size ->
List.concat_map [ (* 0; 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
List.concat_map [ (* "gccjit" ; "cuda";"sync_cc" ; *) "cc" ]
~f:(fun backend_name ->
List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
~f:(fun value_prec ->
[
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
~batch_size ~backend_name ~value_prec ~grad_prec:value_prec;
]))))))
let _cuda_parallel_benchmarks =
List.concat_map
[
(* 1; *)
2;
(* 3; 4; 5; 6; 8; 10; 12; 16; 20 *)
(* 32; 64 *)
] ~f:(fun num_streams ->
List.concat_map
[
(* TINY for debugging: *)
3 * 4
(* 3 * 5 * 16 *)
(* ; 3 * 5 * 32 *);
]
~f:(fun batch_size ->
List.concat_map [ 0 (* 1; 2; 3 *) ] ~f:(fun inlining_cutoff ->
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
List.concat_map [ (* "gccjit"; "cuda" ;"cc"; *) "sync_cc" ]
~f:(fun backend_name ->
List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
~f:(fun value_prec ->
[
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
~batch_size ~backend_name ~value_prec ~grad_prec:value_prec;
]))))))
let _mem_benchmarks =
List.concat_map [ 1; 3; 6; 12; 16 (* ; 20; 32; 64 *) ] ~f:(fun num_streams ->
List.concat_map
[
(* TINY for debugging: *)
(* 3 * 2 *)
3 * 5 * 16 (* ; 3 * 5 * 32; 3 * 5 * 64 *);
]
~f:(fun batch_size ->
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
List.concat_map [ (* "gccjit" ; *) "cc"; "cuda" ] ~f:(fun backend_name ->
List.concat_map [ (* CDSL.double; *) CDSL.single; CDSL.half ]
~f:(fun value_prec ->
[
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
~batch_size ~backend_name ~value_prec ~grad_prec:value_prec;
]))))))
(* let time_of = function PrintBox_utils.Benchmark { time_in_sec; _ } -> time_in_sec

   let nth_best nth bench =
     let results = List.init 5 ~f:(fun seed -> bench ~seed ()) in
     let sorted =
       List.sort results ~compare:(fun r1 r2 -> Float.compare (time_of r1) (time_of r2))
     in
     List.nth_exn sorted (nth - 1) *)
let fixed_seed_search seed =
classify_moons ~seed ~on_device:true ~inlining_cutoff:3 ~num_streams:1 ~batch_size:20
~backend_name:"cuda" ~value_prec:CDSL.single ~grad_prec:CDSL.single ()
let _suspended () =
List.init 20 ~f:fixed_seed_search |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout
(* let () =
     List.map benchmarks ~f:(nth_best 2) |> PrintBox_utils.table
     |> PrintBox_text.output Stdio.stdout *)
let _suspended () =
[
classify_moons ~seed:7 ~on_device:true ~inlining_cutoff:0 ~num_streams:3 ~batch_size:240
~backend_name:"cc" ~value_prec:CDSL.half ~grad_prec:CDSL.half ();
]
|> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout
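(* Force each suspended benchmark in turn and render all collected rows as one text table. *)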
let benchmark benchmarks =
List.map benchmarks ~f:(fun bench -> bench ())
|> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout
let _suspended () = benchmark _cuda_parallel_benchmarks
let () = benchmark _cuda_benchmarks