From d4277b2a1f2f5d3473726ae1964a0face6d6c167 Mon Sep 17 00:00:00 2001 From: Lukasz Stafiniak Date: Wed, 11 Dec 2024 00:13:39 +0100 Subject: [PATCH] Fixes the memory model: on-host arrays can be in contexts Double check: not possible it would trigger freeing host array. Still broken: cc backend tests hang. --- arrayjit/lib/anatomy_of_a_backend.md | 13 +++- arrayjit/lib/assignments.ml | 11 +-- arrayjit/lib/backend_impl.ml | 12 +-- arrayjit/lib/backend_intf.ml | 2 - arrayjit/lib/backends.ml | 35 +++++---- arrayjit/lib/c_syntax.ml | 10 ++- arrayjit/lib/cc_backend.ml | 29 ++++--- arrayjit/lib/cuda_backend.cudajit.ml | 17 ++-- arrayjit/lib/gcc_backend.gccjit.ml | 28 ++++--- arrayjit/lib/lowered_backend_missing.ml | 3 +- arrayjit/lib/no_device_backend_missing.ml | 3 +- arrayjit/lib/tnode.ml | 94 +++++++++++------------ arrayjit/lib/utils.ml | 4 +- bin/moons_benchmark.ml | 62 +++++++-------- lib/attic.mld | 9 +++ lib/train.ml | 33 ++++---- 16 files changed, 196 insertions(+), 169 deletions(-) diff --git a/arrayjit/lib/anatomy_of_a_backend.md b/arrayjit/lib/anatomy_of_a_backend.md index defa9016..f464d651 100644 --- a/arrayjit/lib/anatomy_of_a_backend.md +++ b/arrayjit/lib/anatomy_of_a_backend.md @@ -57,7 +57,7 @@ In the future, when we introduce program search, `compile` functions will return OCANNL classifies tensor nodes according to their memory properties: ```ocaml -(** A possible algorithm for deciding sharing within a single device: + (** A possible algorithm for deciding sharing within a single device: - If a tensor node is read-only for a context, and not otherwise recorded, it is stored as a cross-stream sharing candidate. - If a cross-stream sharing candidate is read-only for another context, whose parent does not @@ -71,9 +71,14 @@ OCANNL classifies tensor nodes according to their memory properties: If a tensor node is shared cross-stream, within-device copying is a NOOP as source and destination pointers are in that case identical. *) type sharing = - | Unset + | Unset (** One of: [Per_stream], [Shared_cross_streams]. *) | Per_stream (** The tensor node has separate arrays for each stream. *) - | Shared_cross_stream (** The tensor node has a single array per device. *) + | Shared_cross_streams + (** The tensor node has a single array per device that can appear in multiple contexts, except + for backends with [Option.is_some use_host_memory] and nodes with memory mode already + [Hosted (Changed_on_devices Shared_cross_streams)] before first linking on a device, where + it only has the on-host array. In that case the on-host array is registered in the + context, to avoid misleading behavior from `device_to_device`. *) type memory_type = | Constant (** The tensor node does not change after initialization. *) @@ -110,6 +115,8 @@ A backend can make more refined distinctions, for example a `Local` node in CUDA Contexts track (or store) the on-device arrays corresponding to tensor nodes. Contexts form a hierarchy: linking takes a parent context and outputs a child context. Related contexts that use a tensor node must use the same on-device array for the tensor node. If two unrelated contexts are on the same device, i.e. have a common ancestor, and use the same tensor node that is not part of the most recent common ancestor, the behavior is undefined. +To avoid misleading behavior of `device_to_device` data movement, non-constant materialized tensor nodes are represented in contexts making use of them, even when the underlying array is on host. This way the logic remains the same regardless of whether a backend shares memory with the host. We are careful to not accidentally call `free_buffer` on hosted arrays. + ## Typical details of a backend implementation During the compilation process, the old context cannot be available when `compile` is handled. Currently, all backends generate context-and-device-independent kernels, that refer to context arrays via parameters. diff --git a/arrayjit/lib/assignments.ml b/arrayjit/lib/assignments.ml index 501c3f00..12ddae5b 100644 --- a/arrayjit/lib/assignments.ml +++ b/arrayjit/lib/assignments.ml @@ -77,16 +77,13 @@ let get_name_exn asgns = let is_total ~initialize_neutral ~projections = initialize_neutral && Indexing.is_bijective projections -(** Returns materialized nodes in the sense of {!Tnode.is_in_context}. NOTE: it should be called - after compilation and ideally after linking with the relevant contexts; otherwise, it is an - under-estimate. *) -let%debug3_sexp context_nodes ~(use_host_memory : bool) (asgns : t) : Tn.t_set = +(** Returns materialized nodes in the sense of {!Tnode.is_in_context_force}. NOTE: it must be called + after compilation; otherwise, it will disrupt memory mode inference. *) +let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_set = let open Utils.Set_O in let empty = Set.empty (module Tn) in let one tn = - if Option.value ~default:false @@ Tnode.is_in_context ~use_host_memory tn then - Set.singleton (module Tn) tn - else empty + if Tn.is_in_context_force ~use_host_memory tn 34 then Set.singleton (module Tn) tn else empty in let of_node = function Node rhs -> one rhs | Merge_buffer _ -> empty in let rec loop = function diff --git a/arrayjit/lib/backend_impl.ml b/arrayjit/lib/backend_impl.ml index 8574967e..13674f92 100644 --- a/arrayjit/lib/backend_impl.ml +++ b/arrayjit/lib/backend_impl.ml @@ -16,6 +16,8 @@ open Backend_intf module type No_device_buffer_and_copying = sig include Alloc_buffer with type stream := unit + val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option + val get_used_memory : unit -> int (** Returns (an upper bound of) the memory used for arrays, in bytes. *) @@ -28,6 +30,7 @@ module No_device_buffer_and_copying () : No_device_buffer_and_copying with type buffer_ptr = unit Ctypes.ptr = struct type buffer_ptr = unit Ctypes.ptr + let use_host_memory = Some Fn.id let sexp_of_buffer_ptr = Ops.sexp_of_voidptr include Buffer_types (struct @@ -70,8 +73,6 @@ module No_device_buffer_and_copying () : Ctypes_memory_stubs.memcpy ~dst:(Ndarray.get_fatptr_not_managed dst) ~src ~size:(Ndarray.size_in_bytes dst) - - let c_ptr_to_string = Some Ops.c_ptr_to_string end module Device_types (Device_config : Device_config) = struct @@ -133,10 +134,11 @@ end module type Backend_impl_common = sig include Buffer - val use_host_memory : bool - (** If true, the backend will read from and write to the host memory directly whenever possible. + val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option + (** If not [None], the backend will read from and write to the host memory directly whenever + reasonable. - [use_host_memory] can only be true on unified memory devices, like CPU and Apple Metal. *) + [use_host_memory] can only be [Some] on unified memory devices, like CPU and Apple Metal. *) end (** An interface to adding schedulers for stream-agnostic (typically CPU) backend implementations. diff --git a/arrayjit/lib/backend_intf.ml b/arrayjit/lib/backend_intf.ml index 9cb0199a..5c83ad91 100644 --- a/arrayjit/lib/backend_intf.ml +++ b/arrayjit/lib/backend_intf.ml @@ -18,8 +18,6 @@ end module type Buffer = sig type buffer_ptr [@@deriving sexp_of] - val c_ptr_to_string : (buffer_ptr -> Ops.prec -> string) option - include module type of Buffer_types (struct type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of] end) diff --git a/arrayjit/lib/backends.ml b/arrayjit/lib/backends.ml index 7085d8b4..96f8407a 100644 --- a/arrayjit/lib/backends.ml +++ b/arrayjit/lib/backends.ml @@ -90,7 +90,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin let ordinal_of ctx = ctx.stream.device.ordinal in let name_of ctx = Backend.(get_name ctx.stream) in let same_device = ordinal_of dst = ordinal_of src in - if same_device && (Tn.known_shared_cross_stream tn || String.equal (name_of src) (name_of dst)) + if same_device && (Tn.known_shared_cross_streams tn || String.equal (name_of src) (name_of dst)) then false else match Map.find src.ctx_arrays tn with @@ -187,8 +187,7 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l = let%debug3_sexp verify_prior_context ~use_host_memory ~ctx_arrays ~from_prior_context : unit = Set.iter from_prior_context ~f:(fun tn -> if - (* Err on the safe side. *) - Option.value ~default:false (Tn.is_in_context ~use_host_memory tn) + Tn.is_in_context_force ~use_host_memory tn 42 && not (Option.is_some @@ Map.find ctx_arrays tn) then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn)) @@ -349,12 +348,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct } let%track3_sexp alloc_if_needed (stream : stream) ~key ~data:node ctx_arrays = - (* TODO: do we need this? *) - (* Tn.default_to_most_local key 345; *) - if - Option.value ~default:true (Tnode.is_in_context ~use_host_memory key) - && not (Map.mem ctx_arrays key) - then ( + if Tnode.is_in_context_force ~use_host_memory key 43 && not (Map.mem ctx_arrays key) then ( [%log Tn.debug_name key]; [%log (key : Tnode.t)]; let default () = @@ -362,14 +356,27 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct in let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in let device = stream.device in - if node.Low_level.read_only then + if node.Low_level.read_only then ( if Tn.known_non_cross_stream key then add_new () - else ( + else + let data = + match use_host_memory with + | None -> Hashtbl.find_or_add device.cross_stream_candidates key ~default + | Some get_buffer_ptr -> + if + (not (Hashtbl.mem device.cross_stream_candidates key)) + && Tn.known_shared_cross_streams key && Tn.is_hosted_force key 44 + then + Hashtbl.update_and_return device.cross_stream_candidates key ~f:(fun _ -> + get_buffer_ptr @@ Ndarray.get_voidptr_not_managed + @@ Option.value_exn ~here:[%here] + @@ Lazy.force key.array) + else Hashtbl.find_or_add device.cross_stream_candidates key ~default + in if Hashtbl.mem device.cross_stream_candidates key then - Tn.update_memory_sharing key Tn.Shared_cross_stream 39; - let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in + Tn.update_memory_sharing key Tn.Shared_cross_streams 39; Map.add_exn ctx_arrays ~key ~data) - else if Tn.known_shared_cross_stream key then ( + else if Tn.known_shared_cross_streams key then ( if Hashtbl.mem device.owner_stream key then ( if not (equal_stream stream (Hashtbl.find_exn device.owner_stream key)) then raise diff --git a/arrayjit/lib/c_syntax.ml b/arrayjit/lib/c_syntax.ml index 9bbc0ad1..cc1f83de 100644 --- a/arrayjit/lib/c_syntax.ml +++ b/arrayjit/lib/c_syntax.ml @@ -15,7 +15,9 @@ module C_syntax (B : sig (** The low-level prcedure to compile, and the arrays of the context it will be linked to if not shared and already known. *) - val use_host_memory : bool + type buffer_ptr + + val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option val logs_to_stdout : bool val main_kernel_prefix : string val kernel_prep_line : string @@ -29,7 +31,7 @@ struct let get_ident = Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun l -> l.llc) - let in_ctx tn = B.(Tn.is_in_context ~use_host_memory tn) + let in_ctx tn = B.(Tn.is_in_context_force ~use_host_memory tn 46) let pp_zero_out ppf tn = Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn @@ -268,8 +270,8 @@ struct (* A rough approximation to the type Gccjit_backend.mem_properties. *) let backend_info, is_param = if Tn.is_virtual_force tn 334 then ("Virt", false) - else if Option.value ~default:false @@ in_ctx tn then ("Ctx", true) - else if Tn.is_materialized_force tn 335 then ("Global or ctx", true) + else if in_ctx tn then ("Ctx", true) + else if Tn.is_materialized_force tn 335 then ("Global", true) else if Tn.known_not_materialized tn then ("Local", false) else assert false in diff --git a/arrayjit/lib/cc_backend.ml b/arrayjit/lib/cc_backend.ml index 0d5cde2a..15474ca8 100644 --- a/arrayjit/lib/cc_backend.ml +++ b/arrayjit/lib/cc_backend.ml @@ -33,8 +33,6 @@ type procedure = { } [@@deriving sexp_of] -let use_host_memory = true - let get_global_run_id = let next_id = ref 0 in fun () -> @@ -76,6 +74,9 @@ module C_syntax_config (Input : sig end) = struct let procs = Input.procs + + type nonrec buffer_ptr = buffer_ptr + let use_host_memory = use_host_memory let logs_to_stdout = false let main_kernel_prefix = "" @@ -127,14 +128,6 @@ let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized opt let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) = let name : string = code.name in - List.iter code.params ~f:(function - | _, Param_ptr tn when not @@ Tn.known_shared_with_host ~use_host_memory tn -> - if not (Map.mem ctx_arrays tn) then - invalid_arg - [%string - "Cc_backend.link_compiled: node %{Tn.debug_name tn} missing from context: \ - %{Tn.debug_memory_mode tn.Tn.memory_mode}"] - | _ -> ()); let log_file_name = Utils.diagn_log_file [%string "debug-%{runner_label}-%{code.name}.log"] in let run_variadic = [%log_level @@ -158,11 +151,16 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro Param_2f (get_ptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs)) | bs, Param_ptr tn :: ps -> let c_ptr = - if Tn.known_shared_with_host ~use_host_memory tn then - Ndarray.get_voidptr_not_managed - @@ Option.value_exn ~here:[%here] - @@ Lazy.force tn.array - else Map.find_exn ctx_arrays tn + match Map.find ctx_arrays tn with + | None -> + Ndarray.get_voidptr_not_managed + @@ Option.value_exn ~here:[%here] + ~message: + [%string + "Cc_backend.link_compiled: node %{Tn.debug_name tn} missing from \ + context: %{Tn.debug_memory_mode tn.Tn.memory_mode}"] + @@ Lazy.force tn.array + | Some arr -> arr in Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs)) in @@ -174,6 +172,7 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro in let%diagn_l_sexp work () : unit = [%log_result name]; + (* Stdio.printf "launching %s\n" name; *) Indexing.apply run_variadic (); if Utils.debug_log_from_routines () then ( Utils.log_trace_tree (Stdio.In_channel.read_lines log_file_name); diff --git a/arrayjit/lib/cuda_backend.cudajit.ml b/arrayjit/lib/cuda_backend.cudajit.ml index 8dd80a02..eafec59b 100644 --- a/arrayjit/lib/cuda_backend.cudajit.ml +++ b/arrayjit/lib/cuda_backend.cudajit.ml @@ -26,14 +26,13 @@ module Backend_buffer = struct type buffer_ptr = Cu.Deviceptr.t let sexp_of_buffer_ptr ptr = Sexp.Atom (Cu.Deviceptr.string_of ptr) - let c_ptr_to_string = None include Buffer_types (struct type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of] end) end -let use_host_memory = false +let use_host_memory = None module Device_config = struct include Backend_buffer @@ -156,15 +155,18 @@ let%track3_sexp new_stream (device : device) : stream = let cuda_properties = let cache = - lazy - (Array.init (num_devices ()) ~f:(fun ordinal -> - let dev = get_device ~ordinal in - lazy (Cu.Device.get_attributes dev.dev.dev))) + let%debug2_sexp f (ordinal : int) = + let dev = get_device ~ordinal in + lazy (Cu.Device.get_attributes dev.dev.dev) + in + lazy (Array.init (num_devices ()) ~f) in - fun device -> + let%debug2_sexp get_props (device : device) : Cu.Device.attributes = if not @@ is_initialized () then invalid_arg "cuda_properties: CUDA not initialized"; let cache = Lazy.force cache in Lazy.force cache.(device.ordinal) + in + get_props let suggested_num_streams device = match !global_config with @@ -427,6 +429,7 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx (* Map.iteri ctx_arrays ~f:(fun ~key ~data:ptr -> if key.Low_level.zero_initialized then Cu.Stream.memset_d8 ptr Unsigned.UChar.zero ~length:(Tn.size_in_bytes key.Low_level.tn)); *) [%log "launching the kernel"]; + (* Stdio.printf "launching %s\n" name; *) (if Utils.debug_log_from_routines () then Utils.add_log_processor ~prefix:log_id_prefix @@ fun _output -> [%log_block diff --git a/arrayjit/lib/gcc_backend.gccjit.ml b/arrayjit/lib/gcc_backend.gccjit.ml index 3302e75c..124f948e 100644 --- a/arrayjit/lib/gcc_backend.gccjit.ml +++ b/arrayjit/lib/gcc_backend.gccjit.ml @@ -66,7 +66,7 @@ type procedure = { } [@@deriving sexp_of] -let use_host_memory = true +(* let use_host_memory = true *) let gcc_typ_of_prec = let open Gccjit in @@ -114,14 +114,14 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~param_ptrs ini let ptr_typ = Type.pointer num_typ in let ident = get_ident tn in let hosted = Tn.is_hosted_force tn 344 in - let in_ctx = Tn.is_in_context ~use_host_memory tn in + let in_ctx = Tn.is_in_context_force ~use_host_memory tn 45 in let ptr = match (in_ctx, hosted) with - | Some true, _ -> + | true, _ -> let p = Param.create ctx ptr_typ ident in param_ptrs := (p, Param_ptr tn) :: !param_ptrs; Lazy.from_val (RValue.param p) - | (Some false | None), true -> ( + | false, true -> ( let addr arr = Lazy.from_val @@ get_c_ptr ctx num_typ @@ Ctypes.bigarray_start Ctypes_static.Genarray arr in @@ -131,7 +131,7 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~param_ptrs ini | Some (Single_nd arr) -> addr arr | Some (Double_nd arr) -> addr arr | None -> assert false) - | (Some false | None), false -> + | false, false -> let arr_typ = Type.array ctx num_typ size_in_elems in let v = ref None in let initialize _init_block func = v := Some (Function.local func arr_typ ident) in @@ -644,7 +644,8 @@ let%diagn_sexp compile_batch ~(names : string option array) bindings let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) = let name : string = code.name in List.iter code.params ~f:(function - | Param_ptr tn when not (Tn.known_shared_cross_stream tn) -> assert (Map.mem ctx_arrays tn) + (* FIXME: see cc_backend.ml *) + | Param_ptr tn when not (Tn.known_shared_cross_streams tn) -> assert (Map.mem ctx_arrays tn) | _ -> ()); let log_file_name = Utils.diagn_log_file [%string "debug-%{runner_label}-%{code.name}.log"] in let run_variadic = @@ -667,11 +668,16 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs)) | bs, Param_ptr tn :: ps -> let c_ptr = - if Tn.known_shared_with_host ~use_host_memory tn then - Ndarray.get_voidptr_not_managed - @@ Option.value_exn ~here:[%here] - @@ Lazy.force tn.array - else Map.find_exn ctx_arrays tn + match Map.find ctx_arrays tn with + | None -> + Ndarray.get_voidptr_not_managed + @@ Option.value_exn ~here:[%here] + ~message: + [%string + "Gcc_backend.link_compiled: node %{Tn.debug_name tn} missing from \ + context: %{Tn.debug_memory_mode tn.Tn.memory_mode}"] + @@ Lazy.force tn.array + | Some arr -> arr in Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs)) | bs, Merge_buffer :: ps -> diff --git a/arrayjit/lib/lowered_backend_missing.ml b/arrayjit/lib/lowered_backend_missing.ml index 57205070..5e5217d7 100644 --- a/arrayjit/lib/lowered_backend_missing.ml +++ b/arrayjit/lib/lowered_backend_missing.ml @@ -3,7 +3,7 @@ type dev type runner type event -let use_host_memory = false +let use_host_memory = None let sexp_of_dev _dev = failwith "Backend missing -- install the corresponding library" let sexp_of_runner _runner = failwith "Backend missing -- install the corresponding library" let sexp_of_event _event = failwith "Backend missing -- install the corresponding library" @@ -39,7 +39,6 @@ let make_child ?ctx_arrays:_ _context = let get_name _stream = failwith "Backend missing -- install the corresponding library" let sexp_of_buffer_ptr _buffer_ptr = failwith "Backend missing -- install the corresponding library" -let c_ptr_to_string = None type nonrec buffer = buffer_ptr Backend_intf.buffer diff --git a/arrayjit/lib/no_device_backend_missing.ml b/arrayjit/lib/no_device_backend_missing.ml index f5f6f7b9..6d8805ee 100644 --- a/arrayjit/lib/no_device_backend_missing.ml +++ b/arrayjit/lib/no_device_backend_missing.ml @@ -1,6 +1,6 @@ type buffer_ptr -let use_host_memory = false +let use_host_memory = None let initialize _config = failwith "Backend missing -- install the corresponding library" let is_initialized () = failwith "Backend missing -- install the corresponding library" let name = "Backend missing" @@ -19,7 +19,6 @@ let link_compiled ~merge_buffer:_ ~runner_label:_ _ctx_arrays _procedure = failwith "Backend missing -- install the corresponding library" let sexp_of_buffer_ptr _buffer_ptr = failwith "Backend missing -- install the corresponding library" -let c_ptr_to_string = None type nonrec buffer = buffer_ptr Backend_intf.buffer diff --git a/arrayjit/lib/tnode.ml b/arrayjit/lib/tnode.ml index cc45d6ac..52b09810 100644 --- a/arrayjit/lib/tnode.ml +++ b/arrayjit/lib/tnode.ml @@ -22,13 +22,14 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime If a tensor node is shared cross-stream, within-device copying is a NOOP as source and destination pointers are in that case identical. *) type sharing = - | Unset + | Unset (** One of: [Per_stream], [Shared_cross_streams]. *) | Per_stream (** The tensor node has separate arrays for each stream. *) - | Shared_cross_stream + | Shared_cross_streams (** The tensor node has a single array per device that can appear in multiple contexts, except - for backends with [use_host_memory = true] and nodes with memory mode - [Hosted (Changed_on_devices Shared_cross_stream)], where it only has the on-host array and - does not appear in any contexts. *) + for backends with [Option.is_some use_host_memory] and nodes with memory mode already + [Hosted (Changed_on_devices Shared_cross_streams)] before first linking on a device, where + it only has the on-host array. In that case the on-host array is registered in the + context, to avoid misleading behavior from `device_to_device`. *) [@@deriving sexp, compare, equal] type memory_type = @@ -129,14 +130,14 @@ let debug_memory_mode = function | Device_only -> "Dev" | Materialized -> "Material" | On_device Unset -> "On-dev" - | On_device Shared_cross_stream -> "Dev-shared" + | On_device Shared_cross_streams -> "Dev-shared" | On_device Per_stream -> "Dev-stream" | Hosted Constant -> "Host-const" | Hosted Nonconstant -> "Host-non-const" | Hosted Volatile -> "Hosted" | Hosted (Changed_on_devices Unset) -> "Host&dev" | Hosted (Changed_on_devices Per_stream) -> "Host&stream" - | Hosted (Changed_on_devices Shared_cross_stream) -> "Host&shared") + | Hosted (Changed_on_devices Shared_cross_streams) -> "Host&shared") ^ "/" ^ Int.to_string prov let log_debug_info ~from_log_level tn = @@ -168,49 +169,42 @@ let default_to_most_local tn provenance = | Some ((Virtual | Local | On_device _ | Hosted _), _) -> () let is_virtual_force tn provenance = - default_to_most_local tn provenance; - match tn.memory_mode with Some (Virtual, _) -> true | _ -> false - -let is_hosted_force ?specifically tn provenance = - default_to_most_local tn provenance; - match (tn.memory_mode, specifically) with - | None, _ -> assert false - | Some ((Virtual | Local | Device_only | On_device _), _), _ -> false - | Some (Hosted _, _), None -> true - | Some (Hosted memtyp, _), Some query -> equal_memory_type memtyp query - | Some ((Never_virtual | Materialized | Effectively_constant), _), _ -> assert false - -let is_materialized_force tn provenance = - default_to_most_local tn provenance; + match tn.memory_mode with + | Some (Virtual, _) -> true + | None | Some (Effectively_constant, _) -> + tn.memory_mode <- Some (Virtual, provenance); + true + | _ -> false + +let rec is_hosted_force tn provenance = + match tn.memory_mode with + | Some ((Virtual | Local | Device_only | On_device _), _) -> false + | Some (Hosted _, _) -> true + | None | Some ((Never_virtual | Materialized | Effectively_constant), _) -> + default_to_most_local tn provenance; + is_hosted_force tn provenance + +let rec is_materialized_force tn provenance = match tn.memory_mode with | None -> assert false | Some ((Virtual | Local), _) -> false | Some ((On_device _ | Hosted _ | Materialized), _) -> true - | Some ((Never_virtual | Device_only | Effectively_constant), _) -> assert false + | Some ((Never_virtual | Device_only | Effectively_constant), _) -> + default_to_most_local tn provenance; + is_materialized_force tn provenance -(* Unlike the [known_] functions which can only change from [false] to [true], [is_in_context - ~use_host_memory tn] is more precise. Generally, it can only change away from [None]. *) -let%debug3_sexp is_in_context ~(use_host_memory : bool) (tn : t) : bool option = - match tn.memory_mode with - | Some (Hosted (Changed_on_devices Per_stream), _) -> Some true - | Some ((Materialized | Hosted Nonconstant), _) when not use_host_memory -> Some true - | Some (Hosted (Constant | Volatile | Changed_on_devices Shared_cross_stream), _) - when use_host_memory -> - Some false - | Some (Hosted Nonconstant, _) when use_host_memory -> None - | Some (Hosted _, _) -> Some true - | Some ((Virtual | Local), _) -> Some false - | None | Some ((Materialized | Effectively_constant | Never_virtual | Device_only), _) -> None - | Some (On_device _, _) -> Some true - -(** The opposite of [is_in_context] for hosted tensor nodes. False if [use_host_memory = false] or - for non-hosted tensor nodes. *) -let known_shared_with_host ~use_host_memory tn = +let%debug3_sexp rec is_in_context_force ~(use_host_memory : 'a option) (tn : t) (provenance : int) : + bool = match tn.memory_mode with - | Some (Hosted (Constant | Volatile | Changed_on_devices Shared_cross_stream), _) - when use_host_memory -> - true - | _ -> false + | Some (Hosted (Changed_on_devices Per_stream), _) -> true + | Some ((Materialized | Hosted Nonconstant), _) when Option.is_none use_host_memory -> true + | Some (Hosted (Constant | Volatile), _) when Option.is_some use_host_memory -> false + | Some (Hosted _, _) -> true + | Some ((Virtual | Local), _) -> false + | None | Some ((Materialized | Effectively_constant | Never_virtual | Device_only), _) -> + default_to_most_local tn provenance; + is_in_context_force ~use_host_memory tn provenance + | Some (On_device _, _) -> true let known_not_materialized tn = match tn.memory_mode with Some ((Virtual | Local), _) -> true | _ -> false @@ -234,11 +228,11 @@ let known_not_param tn = true | _ -> false -let known_shared_cross_stream tn = +let known_shared_cross_streams tn = match tn.memory_mode with | Some - ( ( On_device Shared_cross_stream - | Hosted (Constant | Volatile | Changed_on_devices Shared_cross_stream) ), + ( ( On_device Shared_cross_streams + | Hosted (Constant | Volatile | Changed_on_devices Shared_cross_streams) ), _ ) -> true | _ -> false @@ -299,7 +293,7 @@ let update_memory_mode tn mode provenance = let update_memory_sharing tn sharing provenance = match (tn.memory_mode, sharing) with | None, _ -> tn.memory_mode <- Some (On_device sharing, provenance) - | Some (On_device Shared_cross_stream, _), Shared_cross_stream + | Some (On_device Shared_cross_streams, _), Shared_cross_streams | Some (On_device Per_stream, _), Per_stream -> () | Some ((On_device Unset | Device_only | Materialized), _), _ -> @@ -311,10 +305,10 @@ let update_memory_sharing tn sharing provenance = "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \ %{debug_name tn} (hosted) -- currently hosted nodes not changed on devices must be \ shared cross-stream"] - | Some (Hosted (Changed_on_devices Shared_cross_stream), _), Shared_cross_stream + | Some (Hosted (Changed_on_devices Shared_cross_streams), _), Shared_cross_streams | Some (Hosted (Changed_on_devices Per_stream), _), Per_stream -> () - | Some (Hosted (Constant | Volatile), _), Shared_cross_stream -> () + | Some (Hosted (Constant | Volatile), _), Shared_cross_streams -> () | Some (Hosted (Nonconstant | Changed_on_devices Unset), _), _ -> tn.memory_mode <- Some (Hosted (Changed_on_devices sharing), provenance) | Some (_, prov2), Unset -> diff --git a/arrayjit/lib/utils.ml b/arrayjit/lib/utils.ml index 184dd4e8..f60e5bd4 100644 --- a/arrayjit/lib/utils.ml +++ b/arrayjit/lib/utils.ml @@ -310,7 +310,7 @@ let get_debug name = Minidebug_runtime.forget_printbox @@ Minidebug_runtime.debug ~time_tagged ~elapsed_times ~location_format ~print_entry_ids ~verbose_entry_ids ~global_prefix:name ~toc_entry ~toc_specific_hyperlink:"" - ~highlight_terms:Re.(alt [ str "await" ]) + ~highlight_terms:Re.(alt [ str "float *w2" ]) ~exclude_on_path:Re.(str "env") ~log_level ?snapshot_every_sec () | Some filename -> @@ -319,7 +319,7 @@ let get_debug name = ~print_entry_ids ~verbose_entry_ids ~global_prefix:name ~toc_flame_graph:true ~flame_graph_separation:50 ~toc_entry ~for_append:false ~max_inline_sexp_length:120 ~hyperlink ~toc_specific_hyperlink:"" - ~highlight_terms:Re.(alt [ str "await" ]) + ~highlight_terms:Re.(alt [ str "float *w2" ]) ~exclude_on_path:Re.(str "env") ~backend ~log_level ?snapshot_every_sec filename diff --git a/bin/moons_benchmark.ml b/bin/moons_benchmark.ml index 87ad0525..9c8e9b8e 100644 --- a/bin/moons_benchmark.ml +++ b/bin/moons_benchmark.ml @@ -34,7 +34,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b Tensor.default_grad_prec := grad_prec; Utils.settings.output_debug_files_in_build_directory <- true; (* This will only log from routines if log-level is high enough. *) - (* Utils.settings.debug_log_from_routines <- true; *) + Utils.settings.debug_log_from_routines <- true; Rand.init (* seed *) 0; let hid_dim_1 = 16 in let hid_dim_2 = 8 in @@ -45,14 +45,14 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b let data_len = 3 * 5 * 1024 in (* TINY for debugging: *) (* let data_len = 3 * 4 in *) - (* let data_len = 3 * 16 in *) + (* let data_len = 3 * 8 in *) let flat_len = data_len / 2 in (* Note: [minibatch_size = batch_size / num_streams] is the actual per-device batch used. *) (* let epochs = 400 in *) - (* let epochs = 100 in *) + let epochs = 100 in (* let epochs = 50 in *) (* TINY for debugging: *) - let epochs = 3 in + (* let epochs = 3 in *) (* let epochs = 2 in *) (* let epochs = 1 in *) (* let init_lr = 0.1 in *) @@ -121,22 +121,22 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b Stdio.print_endline "\n******** mlp_result **********"; Tensor.print_tree ~with_id:true ~with_grad:false ~depth:9 model_result; Stdio.printf "\n********\n%!"; - Arrayjit.Tnode.print_accessible_headers (); + (* Arrayjit.Tnode.print_accessible_headers (); *) let callback (x, y) = Float.((infer_callback [| x; y |]).(0) >= 0.) in let%track3_sexp plot_moons () = - [%log_level - 0; - let open PrintBox_utils in - plot - ~size:(120, 40) - (* TINY for debugging: *) - (* ~size:(20, 10) *) - ~x_label:"ixes" ~y_label:"ygreks" - [ - Scatterplot { points = points1; pixel = "#" }; - Scatterplot { points = points2; pixel = "%" }; - Boundary_map { pixel_false = "."; pixel_true = "*"; callback }; - ]] + (* [%log_level 0; *) + let open PrintBox_utils in + plot + ~size:(120, 40) + (* TINY for debugging: *) + (* ~size:(20, 10) *) + ~x_label:"ixes" ~y_label:"ygreks" + [ + Scatterplot { points = points1; pixel = "#" }; + Scatterplot { points = points2; pixel = "%" }; + Boundary_map { pixel_false = "."; pixel_true = "*"; callback }; + ] + (* ] *) in Stdio.printf "\nHalf-moons scatterplot and decision boundary:\n%!"; PrintBox_text.output Stdio.stdout @@ plot_moons (); @@ -190,7 +190,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b } in Stdio.printf "\n\n%!"; - Arrayjit.Tnode.print_accessible_headers (); + (* Arrayjit.Tnode.print_accessible_headers (); *) Stdlib.Format.printf "Final backend global debug info: %a\n%!" Sexp.pp_hum @@ Backend.get_global_debug_info (); result @@ -206,13 +206,13 @@ let _cuda_benchmarks = [ (* TINY for debugging: *) (* 3 * 2 *) - 3 * 5 * 16; - 3 * 5 * 32 (*; 3 * 5 * 64 *); + 3 * 5 * 16 (* ; 3 * 5 * 32; 3 * 5 * 64 *); ] ~f:(fun batch_size -> - List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff -> + List.concat_map [ (* 0; 1; 2; *) 3 ] ~f:(fun inlining_cutoff -> List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed -> - List.concat_map [ (* "gccjit" ; "cc"; *) "cuda" ] ~f:(fun backend_name -> + List.concat_map [ (* "gccjit" ; "cuda";"sync_cc" ; *) "cc"] + ~f:(fun backend_name -> List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ] ~f:(fun value_prec -> [ @@ -223,19 +223,20 @@ let _cuda_benchmarks = let _cuda_parallel_benchmarks = List.concat_map [ - (* 1; 2; *) - 3; - (* 4; 5; 6; 8; 10; 12; 16; 20 *) + (* 1; *) + 2; + (* 3; 4; 5; 6; 8; 10; 12; 16; 20 *) (* 32; 64 *) ] ~f:(fun num_streams -> List.concat_map [ (* TINY for debugging: *) - (* 3 * 4 *) - 3 * 5 * 16 (* ; 3 * 5 * 32 *); + 3 * 4 + (* 3 * 5 * 16 *) + (* ; 3 * 5 * 32 *); ] ~f:(fun batch_size -> - List.concat_map [ (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff -> + List.concat_map [ 0 (* 1; 2; 3 *) ] ~f:(fun inlining_cutoff -> List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed -> List.concat_map [ (* "gccjit"; "cuda" ;"cc"; *) "sync_cc" ] ~f:(fun backend_name -> @@ -291,4 +292,5 @@ let benchmark benchmarks = List.map benchmarks ~f:(fun bench -> bench ()) |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout -let () = benchmark _cuda_parallel_benchmarks +let _suspended () = benchmark _cuda_parallel_benchmarks +let () = benchmark _cuda_benchmarks diff --git a/lib/attic.mld b/lib/attic.mld index b692cace..8fff54cc 100644 --- a/lib/attic.mld +++ b/lib/attic.mld @@ -242,6 +242,15 @@ let get_ptr ~(traced_store : Low_level.traced_store) ~ctx_nodes ~get_ident = | Local_only, _ -> ident) +let is_hosted_force ?specifically tn provenance = + default_to_most_local tn provenance; + match (tn.memory_mode, specifically) with + | None, _ -> assert false + | Some ((Virtual | Local | Device_only | On_device _), _), _ -> false + | Some (Hosted _, _), None -> true + | Some (Hosted memtyp, _), Some query -> equal_memory_type memtyp query + | Some ((Never_virtual | Materialized | Effectively_constant), _), _ -> assert false + let mem_properties (traced_store : Low_level.traced_store) = let cache = Hashtbl.create (module Tn) in fun tn -> diff --git a/lib/train.ml b/lib/train.ml index 6a395829..4bc4aca9 100644 --- a/lib/train.ml +++ b/lib/train.ml @@ -221,14 +221,14 @@ let%track3_sexp sequential_loop ~f lowered_bindings = remain at their initial values. [sync] is called after each round of calling all workers, and at the end if needed, with the number of workers called during the round. *) let%track3_sexp round_robin fs parallel_jitbs jitbs ~sync : unit = - let num_devices : int = Array.length fs in - assert (Array.length parallel_jitbs = num_devices); + let num_streams : int = Array.length fs in + assert (Array.length parallel_jitbs = num_streams); let pos = ref 0 in let rec loop = function | [] -> - fs.(!pos % num_devices) (); + fs.(!pos % num_streams) (); Int.incr pos; - if !pos % num_devices = 0 then sync num_devices + if !pos % num_streams = 0 then sync num_streams | ({ Idx.static_range = None; static_symbol = _ }, _) :: more -> loop more | (({ Idx.static_range = Some range; static_symbol = _ } as s), idx) :: ({ Idx.static_range = None; static_symbol = _ }, _) @@ -236,20 +236,20 @@ let%track3_sexp round_robin fs parallel_jitbs jitbs ~sync : unit = | (({ Idx.static_range = Some range; static_symbol = _ } as s), idx) :: more -> for i = 0 to range - 1 do idx := i; - if List.is_empty more then Idx.find_exn parallel_jitbs.(!pos % num_devices) s := i + if List.is_empty more then Idx.find_exn parallel_jitbs.(!pos % num_streams) s := i else Array.iter parallel_jitbs ~f:(fun jb -> Idx.find_exn jb s := i); loop more done in loop jitbs; - if !pos % num_devices <> 0 then sync (!pos % num_devices) + if !pos % num_streams <> 0 then sync (!pos % num_streams) -let%track3_sexp round_robin_dry_run ~num_devices jitbs ~dry_sync : unit = +let%track3_sexp round_robin_dry_run ~num_streams jitbs ~dry_sync : unit = let pos = ref 0 in let rec loop = function | [] -> Int.incr pos; - if !pos % num_devices = 0 then dry_sync num_devices + if !pos % num_streams = 0 then dry_sync num_streams | ({ Idx.static_range = None; static_symbol = _ }, _) :: more -> loop more | ({ Idx.static_range = Some range; static_symbol = _ }, idx) :: ({ Idx.static_range = None; static_symbol = _ }, _) @@ -261,7 +261,7 @@ let%track3_sexp round_robin_dry_run ~num_devices jitbs ~dry_sync : unit = done in loop jitbs; - if !pos % num_devices <> 0 then dry_sync (!pos % num_devices) + if !pos % num_streams <> 0 then dry_sync (!pos % num_streams) let set_virtual (a : Tn.t) = Tn.update_memory_mode a Virtual 29 @@ -333,16 +333,16 @@ let%track3_sexp parallel_update (type buffer_ptr dev runner event) and type event = event) ~(grad_updates : Backend.context BT.routine array) ~(sgd_update : Backend.context BT.routine) ~copy_to_merge ~post_sync updaten : unit -> unit = assert (not @@ Array.is_empty grad_updates); - let num_devices : int = Array.length grad_updates in + let num_streams : int = Array.length grad_updates in let bindings : Idx.static_symbol list = List.map ~f:fst sgd_update.bindings in let occupancies_dst_src = - Array.init num_devices ~f:(fun _ -> Array.create ~len:num_devices false) + Array.init num_streams ~f:(fun _ -> Array.create ~len:num_streams false) in (* to_, from positions correspond to the contexts (and devices) of grad_updates at the position. *) let dry_merge ~from ~to_ = occupancies_dst_src.(to_).(from) <- true in let dry_sync devices_to_sync = Arrayjit.Utils.parallel_merge dry_merge devices_to_sync in - round_robin_dry_run ~num_devices sgd_update.bindings ~dry_sync; + round_robin_dry_run ~num_streams sgd_update.bindings ~dry_sync; [%debug_notrace assert ( Array.for_all grad_updates ~f:(fun upd -> @@ -402,7 +402,7 @@ let%track3_sexp parallel_update (type buffer_ptr dev runner event) Array.iteri ctxs ~f:(fun i src -> if i <> 0 then merge_loss ~src); Set.iter !needed_on_host ~f:(fun p -> assert (Backend.to_host sgd_update.context p)); (* We will need to update params on all devices! Not only the ones that computed gradients. *) - for to_ = 1 to num_devices - 1 do + for to_ = 1 to num_streams - 1 do Array.iter all_params ~f:(fun p -> (* Allow the params to be shared across streams. *) ignore @@ -534,7 +534,9 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init Tn.log_accessible_headers ()); for epoch = 0 to epochs - 1 do epoch_loss := 0.; - Utils.capture_stdout_logs update; + (* DEBUG: *) + (* Utils.capture_stdout_logs *) + update (); learning_rates := learning_rate.@[0] :: !learning_rates; epoch_losses := !epoch_loss :: !epoch_losses; Option.iter per_epoch_callback ~f:(fun f -> @@ -571,7 +573,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init Tensor.set_values infer values; (* For the gccjit backend, infer is only on host, not on device. For cuda, this will be needed. *) - Utils.capture_stdout_logs @@ fun () -> + (* DEBUG: *) + (* Utils.capture_stdout_logs @@ fun () -> *) assert (Backend.from_host routine.context infer.value); run routine; assert (Backend.to_host routine.context model_result.value);