Skip to content

Commit

Permalink
Work around Ctypes.bigarray_start because it does not support half …
Browse files Browse the repository at this point in the history
…precision
  • Loading branch information
lukstafi committed Sep 12, 2024
1 parent f5dfa12 commit 8a87f01
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### Fixed

- Pass the $CUDA_PATH/include path to the nvrtc compiler; otherwise it will not `#include` anything.
- Work around `Ctypes.bigarray_start` because it does not support half precision.

## [0.4.0] 2024-07-21

Expand Down
12 changes: 9 additions & 3 deletions cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,12 @@ let uint_of_cu_jit_cache_mode c =
| CU_JIT_CACHE_OPTION_CA -> Unsigned.UInt.of_int64 cu_jit_cache_option_ca
| CU_JIT_CACHE_OPTION_UNCATEGORIZED c -> Unsigned.UInt.of_int64 c

(** [bigarray_start_not_managed ba] is the raw address that
    [Ctypes_bigarray.unsafe_address] reports for [ba]. No Ctypes pointer
    wrapper is built here; see the callers for how the address is used. *)
let bigarray_start_not_managed ba = Ctypes_bigarray.unsafe_address ba

(** [get_ptr_not_managed ~reftyp arr] wraps the address of the bigarray [arr]
    in a Ctypes pointer whose referent type is [reftyp].

    This exists as a workaround: [Ctypes.bigarray_start] doesn't support half
    precision, so the pointer is assembled by hand from the raw address.

    NOTE(review): the pointer is created with [make_unmanaged], which
    presumably means it does not keep [arr] alive -- confirm that callers keep
    the bigarray reachable for as long as the pointer is in use. *)
let get_ptr_not_managed ~reftyp arr =
  let raw_address = bigarray_start_not_managed arr in
  let fat_pointer = Ctypes_memory.make_unmanaged ~reftyp raw_address in
  Ctypes_static.CPointer fat_pointer

let module_load_data_ex ptx options =
let open Ctypes in
let cu_mod = allocate_n cu_module ~count:1 in
Expand Down Expand Up @@ -335,7 +341,7 @@ let module_load_data_ex ptx options =
let f2vp f = coerce (ptr float) (ptr void) @@ allocate float f in
let i2vp i = coerce (ptr int) (ptr void) @@ allocate int i in
let bi2vp b = coerce (ptr int) (ptr void) @@ allocate int (if b then 1 else 0) in
let ba2vp b = coerce (ptr char) (ptr void) @@ bigarray_start Ctypes.array1 b in
let ba2vp b = get_ptr_not_managed ~reftyp:Ctypes_static.void b in
let c_opts_args =
CArray.of_list (ptr void)
@@ List.concat_map
Expand Down Expand Up @@ -395,7 +401,7 @@ let memcpy_H_to_D_impl ?host_offset ?length ~dst ~src memcpy =
| _, Some length -> Ctypes.sizeof c_typ * length
in
let open Ctypes in
let host = bigarray_start genarray src in
let host = get_ptr_not_managed ~reftyp:c_typ src in
let host = match host_offset with None -> host | Some offset -> host +@ offset in
memcpy ~dst ~src:(coerce (ptr c_typ) (ptr void) host) ~size_in_bytes

Expand Down Expand Up @@ -473,7 +479,7 @@ let memcpy_D_to_H_impl ?host_offset ?length ~dst ~src memcpy =
| Some offset, Some length -> elem_bytes * (length - offset)
in
let open Ctypes in
let host = bigarray_start genarray dst in
let host = get_ptr_not_managed ~reftyp:c_typ dst in
let host = match host_offset with None -> host | Some offset -> host +@ offset in
memcpy ~dst:(coerce (ptr c_typ) (ptr void) host) ~src ~size_in_bytes

Expand Down

0 comments on commit 8a87f01

Please sign in to comment.