Skip to content

Commit

Permalink
Undo keep_alive for stream; fix: forgot to remove finalizer from th…
Browse files Browse the repository at this point in the history
…e private `Event.create_event`
  • Loading branch information
lukstafi committed Nov 22, 2024
1 parent 3697242 commit d1ff0a5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 29 deletions.
4 changes: 0 additions & 4 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
## [0.6.1] 2024-11-22

### Changed

- `Stream.create` has a new optional argument `keep_alive` to prevent e.g. finalizing the stream's context before the stream.

### Fixed

- Docu-comment typo.
Expand Down
33 changes: 11 additions & 22 deletions cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,6 @@ let destroy_event event = check "cu_event_destroy" @@ Cuda.cu_event_destroy even
let sexp_of_cu_stream (cu_stream : cu_stream) = sexp_of_voidp @@ Ctypes.to_voidp cu_stream

type stream = {
lifetime : (lifetime[@sexp.opaque]);
mutable args_lifetimes : (lifetime list[@sexp.opaque]);
mutable owned_events : delimited_event list;
stream : cu_stream;
Expand All @@ -1111,6 +1110,11 @@ let get_stream_context stream =

let set_current_context ctx = check "cu_ctx_set_current" @@ Cuda.cu_ctx_set_current ctx

let release_event event =
if not event.is_released then (
destroy_event event.event;
event.is_released <- true)

let release_stream stream =
stream.args_lifetimes <- [];
let ctx_unset = ref true in
Expand All @@ -1120,18 +1124,12 @@ let release_stream stream =
if !ctx_unset then (
set_current_context @@ get_stream_context stream;
ctx_unset := false);
destroy_event event.event;
event.is_released <- true))
release_event event))
stream.owned_events;
stream.owned_events <- []

let no_stream =
{
args_lifetimes = [];
owned_events = [];
stream = Ctypes.(coerce (ptr void) cu_stream null);
lifetime = Remember ();
}
{ args_lifetimes = []; owned_events = []; stream = Ctypes.(coerce (ptr void) cu_stream null) }

module Context = struct
type t = cu_context
Expand Down Expand Up @@ -1653,12 +1651,7 @@ module Stream = struct
[@@deriving sexp_of]

let no_stream =
{
args_lifetimes = [];
owned_events = [];
stream = Ctypes.(coerce (ptr void) cu_stream null);
lifetime = Remember ();
}
{ args_lifetimes = []; owned_events = []; stream = Ctypes.(coerce (ptr void) cu_stream null) }

let launch_kernel func ~grid_dim_x ?(grid_dim_y = 1) ?(grid_dim_z = 1) ~block_dim_x
?(block_dim_y = 1) ?(block_dim_z = 1) ~shared_mem_bytes stream kernel_params =
Expand Down Expand Up @@ -1711,15 +1704,14 @@ module Stream = struct
release_stream stream;
check "cu_stream_destroy" @@ Cuda.cu_stream_destroy stream.stream

let create ?keep_alive ?(non_blocking = false) ?(lower_priority = 0) () =
let create ?(non_blocking = false) ?(lower_priority = 0) () =
let open Ctypes in
let stream = allocate_n cu_stream ~count:1 in
check "cu_stream_create_with_priority"
@@ Cuda.cu_stream_create_with_priority stream
(uint_of_cu_stream_flags ~non_blocking)
lower_priority;
let lifetime = Remember keep_alive in
let stream = { args_lifetimes = []; owned_events = []; stream = !@stream; lifetime } in
let stream = { args_lifetimes = []; owned_events = []; stream = !@stream } in
Stdlib.Gc.finalise destroy stream;
stream

Expand Down Expand Up @@ -1810,7 +1802,6 @@ module Event = struct
@@ Cuda.cu_event_create event
(uint_of_cu_event_flags ~blocking_sync ~enable_timing ~interprocess);
let event = !@event in
Gc.finalise destroy event;
event

let create ?blocking_sync ?enable_timing ?interprocess () =
Expand Down Expand Up @@ -1865,9 +1856,7 @@ module Delimited_event = struct
let synchronize event =
if not event.is_released then (
Event.synchronize event.event;
if not event.is_released then (
destroy_event event.event;
event.is_released <- true))
release_event event)

let wait ?external_ stream event =
if not event.is_released then Event.wait ?external_ stream event.event
Expand Down
5 changes: 2 additions & 3 deletions cudajit.mli
Original file line number Diff line number Diff line change
Expand Up @@ -727,13 +727,12 @@ module Stream : sig
{{:https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g6e468d680e263e7eba02a56643c50533}
cuStreamAttachMemAsync}. *)

val create : ?keep_alive:'a -> ?non_blocking:bool -> ?lower_priority:int -> unit -> t
val create : ?non_blocking:bool -> ?lower_priority:int -> unit -> t
(** Lower [lower_priority] numbers represent higher priorities, the default is [0]. See
{{:https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g95c1a8c7c3dacb13091692dd9c7f7471}
cuStreamCreateWithPriority}.
[keep_alive] is kept in the returned value. One use-case is to prevent finalizing the stream's
context before the stream. The stream value is finalized using
The stream value is finalized using
{{:https://developer.download.nvidia.com/compute/DevZone/docs/html/C/doc/html/group__CUDA__STREAM_g244c8833de4596bcd31a06cdf21ee758.html}
cuStreamDestroy}. This is meant to be safe
{{:https://stackoverflow.com/questions/64663943/how-to-destroy-a-stream-that-was-created-on-a-specific-device}
Expand Down

0 comments on commit d1ff0a5

Please sign in to comment.