Skip to content

Commit

Permalink
Work around Ctypes.typ_of_bigarray_kind because it does not support…
Browse files Browse the repository at this point in the history
… half precision
  • Loading branch information
lukstafi committed Sep 12, 2024
1 parent 8a87f01 commit f6d5821
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
### Fixed

- Pass the $CUDA_PATH/include path to the nvrtc compiler; otherwise it will not `#include` anything.
- Work around `Ctypes.bigarray_start` because it does not support half precision.
- Work around `Ctypes.bigarray_start` and `typ_of_bigarray_kind` because `ctypes` does not support half precision.

## [0.4.0] 2024-07-21

Expand Down
33 changes: 21 additions & 12 deletions cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -392,18 +392,23 @@ let mem_alloc ~size_in_bytes =

let memcpy_H_to_D_impl ?host_offset ?length ~dst ~src memcpy =
let full_size = Bigarray.Genarray.size_in_bytes src in
let c_typ = Ctypes.typ_of_bigarray_kind @@ Bigarray.Genarray.kind src in
let elem_bytes = Bigarray.kind_size_in_bytes @@ Bigarray.Genarray.kind src in
let size_in_bytes =
match (host_offset, length) with
| None, None -> full_size
| Some _, None ->
invalid_arg "Cudajit.memcpy_H_to_D: providing offset requires providing length"
| _, Some length -> Ctypes.sizeof c_typ * length
| _, Some length -> elem_bytes * length
in
let open Ctypes in
let host = get_ptr_not_managed ~reftyp:c_typ src in
let host = match host_offset with None -> host | Some offset -> host +@ offset in
memcpy ~dst ~src:(coerce (ptr c_typ) (ptr void) host) ~size_in_bytes
let host =
match host_offset with
| None -> get_ptr_not_managed ~reftyp:void src
| Some offset ->
let host = get_ptr_not_managed ~reftyp:uint8_t src in
coerce (ptr uint8_t) (ptr void) @@ (host +@ (offset * elem_bytes))
in
memcpy ~dst ~src:host ~size_in_bytes

let memcpy_H_to_D_unsafe ~dst:(Deviceptr dst) ~(src : unit Ctypes.ptr) ~size_in_bytes =
check "cu_memcpy_H_to_D" @@ Cuda.cu_memcpy_H_to_D dst src @@ Unsigned.Size_t.of_int size_in_bytes
Expand Down Expand Up @@ -469,8 +474,7 @@ let ctx_synchronize () =

let memcpy_D_to_H_impl ?host_offset ?length ~dst ~src memcpy =
let full_size = Bigarray.Genarray.size_in_bytes dst in
let c_typ = Ctypes.typ_of_bigarray_kind @@ Bigarray.Genarray.kind dst in
let elem_bytes = Ctypes.sizeof c_typ in
let elem_bytes = Bigarray.kind_size_in_bytes @@ Bigarray.Genarray.kind dst in
let size_in_bytes =
match (host_offset, length) with
| None, None -> full_size
Expand All @@ -479,9 +483,15 @@ let memcpy_D_to_H_impl ?host_offset ?length ~dst ~src memcpy =
| Some offset, Some length -> elem_bytes * (length - offset)
in
let open Ctypes in
let host = get_ptr_not_managed ~reftyp:c_typ dst in
let host = match host_offset with None -> host | Some offset -> host +@ offset in
memcpy ~dst:(coerce (ptr c_typ) (ptr void) host) ~src ~size_in_bytes
let host =
match host_offset with
| None -> get_ptr_not_managed ~reftyp:void dst
| Some offset ->
let host = get_ptr_not_managed ~reftyp:uint8_t dst in
let host = host +@ (offset * elem_bytes) in
coerce (ptr uint8_t) (ptr void) host
in
memcpy ~dst:host ~src ~size_in_bytes

let memcpy_D_to_H_unsafe ~(dst : unit Ctypes.ptr) ~src:(Deviceptr src) ~size_in_bytes =
check "cu_memcpy_D_to_H" @@ Cuda.cu_memcpy_D_to_H dst src @@ Unsigned.Size_t.of_int size_in_bytes
Expand All @@ -500,8 +510,7 @@ let get_size_in_bytes ?kind ?length ?size_in_bytes provenance =
match (size_in_bytes, kind, length) with
| Some size, None, None -> size
| None, Some kind, Some length ->
let c_typ = Ctypes.typ_of_bigarray_kind kind in
let elem_bytes = Ctypes.sizeof c_typ in
let elem_bytes = Bigarray.kind_size_in_bytes kind in
elem_bytes * length
| Some _, Some _, Some _ ->
invalid_arg @@ provenance
Expand Down

0 comments on commit f6d5821

Please sign in to comment.