Skip to content

Commit

Permalink
Streams (except cuStreamWaitEvent and graph capture)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jul 5, 2024
1 parent 1bb8faa commit 4efc74d
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 2 deletions.
4 changes: 2 additions & 2 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
## [0.3.0] current
## [0.3.0] 2024-07-05

### Added

- TODO: Support for streams.
- Support for streams (except `cuStreamWaitEvent` and graph capture).
- TODO: Support for asynchronous copying.

## [0.2.0] 2024-05-18
Expand Down
16 changes: 16 additions & 0 deletions cuda_ffi/bindings.ml
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,20 @@ module Functions (F : Ctypes.FOREIGN) = struct

(** Foreign binding for the C symbol [cuCtxSetLimit]: takes an [E.cu_limit] and a [size_t], and
    yields an [E.cu_result]. *)
let cu_ctx_set_limit =
  let signature = F.(E.cu_limit @-> size_t @-> returning E.cu_result) in
  F.foreign "cuCtxSetLimit" signature
(** Foreign binding for the C symbol [cuCtxGetLimit]: takes a pointer to a [size_t] (the output
    slot) followed by an [E.cu_limit], and yields an [E.cu_result]. *)
let cu_ctx_get_limit =
  let signature = F.(ptr size_t @-> E.cu_limit @-> returning E.cu_result) in
  F.foreign "cuCtxGetLimit" signature

(** Foreign binding for the C symbol [cuStreamAttachMemAsync]. Arguments, in order: a [cu_stream],
    a [cu_deviceptr], a [size_t] length and a [uint] flags word; yields an [E.cu_result]. *)
let cu_stream_attach_mem_async =
  let signature =
    F.(cu_stream @-> cu_deviceptr @-> size_t @-> uint @-> returning E.cu_result)
  in
  F.foreign "cuStreamAttachMemAsync" signature

(** Foreign binding for the C symbol [cuStreamCreateWithPriority]. Arguments, in order: a pointer
    to a [cu_stream] (the output slot), a [uint] flags word and an [int] priority; yields an
    [E.cu_result]. *)
let cu_stream_create_with_priority =
  let signature = F.(ptr cu_stream @-> uint @-> int @-> returning E.cu_result) in
  F.foreign "cuStreamCreateWithPriority" signature

(** Foreign binding for the C symbol [cuStreamDestroy]: takes a [cu_stream] and yields an
    [E.cu_result]. *)
let cu_stream_destroy =
  let signature = F.(cu_stream @-> returning E.cu_result) in
  F.foreign "cuStreamDestroy" signature

(** Foreign binding for the C symbol [cuStreamGetCtx]: takes a [cu_stream] and a pointer to a
    [cu_context] (the output slot), and yields an [E.cu_result]. *)
let cu_stream_get_ctx =
  let signature = F.(cu_stream @-> ptr cu_context @-> returning E.cu_result) in
  F.foreign "cuStreamGetCtx" signature

(** Foreign binding for the C symbol [cuStreamGetId]: takes a [cu_stream] and a pointer to a
    [uint64_t] (the output slot), and yields an [E.cu_result]. *)
let cu_stream_get_id =
  let signature = F.(cu_stream @-> ptr uint64_t @-> returning E.cu_result) in
  F.foreign "cuStreamGetId" signature
(** Foreign binding for the C symbol [cuStreamQuery]: takes a [cu_stream] and yields an
    [E.cu_result]. *)
let cu_stream_query =
  let signature = F.(cu_stream @-> returning E.cu_result) in
  F.foreign "cuStreamQuery" signature
(** Foreign binding for the C symbol [cuStreamSynchronize]: takes a [cu_stream] and yields an
    [E.cu_result]. *)
let cu_stream_synchronize =
  let signature = F.(cu_stream @-> returning E.cu_result) in
  F.foreign "cuStreamSynchronize" signature
end
33 changes: 33 additions & 0 deletions cuda_ffi/bindings_types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,16 @@ type cu_ctx_flags =
| CU_CTX_FLAGS_UNCATEGORIZED of int64
[@@deriving sexp]

(** OCaml view of the C enum that this file binds under the name [CUmemAttach_flags]. The three
    named cases correspond to the identically named C constants; a value matching none of them is
    represented by [CU_MEM_ATTACH_FLAGS_UNCATEGORIZED], which carries the raw [int64]. *)
type cu_mem_attach_flags =
| CU_MEM_ATTACH_GLOBAL
| CU_MEM_ATTACH_HOST
| CU_MEM_ATTACH_SINGLE
| CU_MEM_ATTACH_FLAGS_UNCATEGORIZED of int64
[@@deriving sexp]

(** OCaml view of the C enum that this file binds under the name [CUstream_flags]. A value that is
    neither [CU_STREAM_DEFAULT] nor [CU_STREAM_NON_BLOCKING] is represented by
    [CU_STREAM_FLAGS_UNCATEGORIZED], which carries the raw [int64]. *)
type cu_stream_flags = CU_STREAM_DEFAULT | CU_STREAM_NON_BLOCKING | CU_STREAM_FLAGS_UNCATEGORIZED of int64
[@@deriving sexp]

module Types (T : Ctypes.TYPE) = struct
let cu_device_v1 = T.typedef T.int "CUdevice_v1"
let cu_device_t = T.typedef cu_device_v1 "CUdevice"
Expand Down Expand Up @@ -1369,4 +1379,27 @@ module Types (T : Ctypes.TYPE) = struct
(CU_CTX_SYNC_MEMOPS, cu_ctx_sync_memops);
(CU_CTX_FLAGS_MASK, cu_ctx_flags_mask);
]

(* Values of the C constants CU_MEM_ATTACH_*, each read as an int64_t constant. *)
let cu_mem_attach_global = T.(constant "CU_MEM_ATTACH_GLOBAL" int64_t)
let cu_mem_attach_host = T.(constant "CU_MEM_ATTACH_HOST" int64_t)
let cu_mem_attach_single = T.(constant "CU_MEM_ATTACH_SINGLE" int64_t)

(** Enum type description for the C typedef [CUmemAttach_flags]. Each named OCaml constructor is
    paired with its C constant; any other value is wrapped in
    [CU_MEM_ATTACH_FLAGS_UNCATEGORIZED]. *)
let cu_mem_attach_flags =
  (* Fallback for values that match none of the listed constants. *)
  let unexpected raw_value = CU_MEM_ATTACH_FLAGS_UNCATEGORIZED raw_value in
  let cases =
    [
      (CU_MEM_ATTACH_GLOBAL, cu_mem_attach_global);
      (CU_MEM_ATTACH_HOST, cu_mem_attach_host);
      (CU_MEM_ATTACH_SINGLE, cu_mem_attach_single);
    ]
  in
  T.enum ~typedef:true ~unexpected "CUmemAttach_flags" cases

(* Values of the C constants CU_STREAM_*, each read as an int64_t constant. *)
let cu_stream_default = T.(constant "CU_STREAM_DEFAULT" int64_t)
let cu_stream_non_blocking = T.(constant "CU_STREAM_NON_BLOCKING" int64_t)

(** Enum type description for the C typedef [CUstream_flags]. Each named OCaml constructor is
    paired with its C constant; any other value is wrapped in [CU_STREAM_FLAGS_UNCATEGORIZED]. *)
let cu_stream_flags =
  (* Fallback for values that match none of the listed constants. *)
  let unexpected raw_value = CU_STREAM_FLAGS_UNCATEGORIZED raw_value in
  let cases =
    [
      (CU_STREAM_DEFAULT, cu_stream_default);
      (CU_STREAM_NON_BLOCKING, cu_stream_non_blocking);
    ]
  in
  T.enum ~typedef:true ~unexpected "CUstream_flags" cases
end
50 changes: 50 additions & 0 deletions cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,56 @@ let ctx_get_limit limit =
check "cu_ctx_set_limit" @@ Cuda.cu_ctx_get_limit value limit;
!@value

(** Attachment mode passed to [stream_attach_mem_async]. [uint_of_attach_mem] translates
    [Mem_global], [Mem_host] and [Mem_single_stream] to the constants [cu_mem_attach_global],
    [cu_mem_attach_host] and [cu_mem_attach_single] respectively. *)
type attach_mem = Mem_global | Mem_host | Mem_single_stream [@@deriving sexp]

(** [uint_of_attach_mem f] converts an [attach_mem] case to the unsigned int value of the matching
    [cu_mem_attach_*] constant from [Cuda_ffi.Types_generated]. *)
let uint_of_attach_mem f =
  let open Cuda_ffi.Types_generated in
  (* Select the int64 constant first, then convert once. *)
  let raw_flag =
    match f with
    | Mem_global -> cu_mem_attach_global
    | Mem_host -> cu_mem_attach_host
    | Mem_single_stream -> cu_mem_attach_single
  in
  Unsigned.UInt.of_int64 raw_flag

(** [stream_attach_mem_async stream device length flag] calls [Cuda.cu_stream_attach_mem_async]
    with [length] converted to a [size_t] and [flag] converted via [uint_of_attach_mem], then hands
    the resulting status to [check] under the label ["cu_stream_attach_mem_async"]. *)
let stream_attach_mem_async stream device length flag =
  let attach_flag = uint_of_attach_mem flag in
  let length_sz = Unsigned.Size_t.of_int length in
  let status = Cuda.cu_stream_attach_mem_async stream device length_sz attach_flag in
  check "cu_stream_attach_mem_async" status

(** [uint_of_cu_stream_flags ~non_blocking] is the unsigned int value of [cu_stream_non_blocking]
    when [non_blocking] is [true], and of [cu_stream_default] otherwise. *)
let uint_of_cu_stream_flags ~non_blocking =
  let open Cuda_ffi.Types_generated in
  let raw_flag = if non_blocking then cu_stream_non_blocking else cu_stream_default in
  Unsigned.UInt.of_int64 raw_flag

(** [stream_create ?non_blocking ?lower_priority ()] creates a stream through
    [Cuda.cu_stream_create_with_priority] and returns it.

    @param non_blocking
      chooses between the non-blocking and the default stream flag (see
      [uint_of_cu_stream_flags]); defaults to [false].
    @param lower_priority passed unchanged as the priority argument; defaults to [0].

    The status of the call is handed to [check] under the label
    ["cu_stream_create_with_priority"] before the stream is read back. *)
let stream_create ?(non_blocking = false) ?(lower_priority = 0) () =
  let open Ctypes in
  (* Output slot that the driver call fills in. *)
  let stream_slot = allocate_n cu_stream ~count:1 in
  let flags = uint_of_cu_stream_flags ~non_blocking in
  let status = Cuda.cu_stream_create_with_priority stream_slot flags lower_priority in
  check "cu_stream_create_with_priority" status;
  !@stream_slot

(** [stream_destroy stream] calls [Cuda.cu_stream_destroy] on [stream] and hands the status to
    [check] under the label ["cu_stream_destroy"]. *)
let stream_destroy stream =
  let status = Cuda.cu_stream_destroy stream in
  check "cu_stream_destroy" status

(** [stream_get_context stream] returns the context that [Cuda.cu_stream_get_ctx] reports for
    [stream]. The status of the call is handed to [check] under the label ["cu_stream_get_ctx"]
    before the result is read back. *)
let stream_get_context stream =
  let open Ctypes in
  (* Output slot that the driver call fills in. *)
  let ctx_slot = allocate_n cu_context ~count:1 in
  let status = Cuda.cu_stream_get_ctx stream ctx_slot in
  check "cu_stream_get_ctx" status;
  !@ctx_slot

(** [stream_get_id stream] returns the [uint64_t] identifier that [Cuda.cu_stream_get_id] reports
    for [stream]. The status of the call is handed to [check] under the label
    ["cu_stream_get_id"] before the result is read back. *)
let stream_get_id stream =
  let open Ctypes in
  (* Output slot, zero-initialised, that the driver call fills in. *)
  let id_slot = allocate uint64_t Unsigned.UInt64.zero in
  let status = Cuda.cu_stream_get_id stream id_slot in
  check "cu_stream_get_id" status;
  !@id_slot

(** [stream_is_ready stream] queries [stream] with [Cuda.cu_stream_query]. It is [false] when the
    query reports [CUDA_ERROR_NOT_READY]. Any other status is handed to [check] under the label
    ["cu_stream_query"]; if [check] returns, the result is [true]. *)
let stream_is_ready stream =
  let status = Cuda.cu_stream_query stream in
  match status with
  | CUDA_ERROR_NOT_READY -> false
  | other_status ->
      (* Not-ready is the only status treated as a plain negative answer. *)
      check "cu_stream_query" other_status;
      true

(** [stream_synchronize stream] calls [Cuda.cu_stream_synchronize] on [stream] and hands the status
    to [check] under the label ["cu_stream_synchronize"]. *)
let stream_synchronize stream =
  let status = Cuda.cu_stream_synchronize stream in
  check "cu_stream_synchronize" status

(* Short aliases for the underlying handle types. *)

(** Alias of [cu_context]. *)
type context = cu_context
(** Alias of [cu_function]. *)
type func = cu_function
(** Alias of [cu_stream]. *)
type stream = cu_stream
Expand Down

0 comments on commit 4efc74d

Please sign in to comment.