Skip to content

Commit

Permalink
Remove Device.primary_ctx_release from the API, instead call it fro…
Browse files Browse the repository at this point in the history
…m `get_primary`'s finalizer;

clarify that `Context.create` is a usually a bad idea.
  • Loading branch information
lukstafi committed Sep 30, 2024
1 parent bfd84a5 commit 7ce920d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 21 deletions.
14 changes: 8 additions & 6 deletions cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1149,12 +1149,6 @@ module Context = struct
SYNC_MEMOPS;
]

let get_primary device =
let open Ctypes in
let ctx = allocate_n cu_context ~count:1 in
check "cu_device_primary_ctx_retain" @@ Cuda.cu_device_primary_ctx_retain ctx device;
!@ctx

let get_device () =
let open Ctypes in
let device = allocate Cuda_ffi.Types_generated.cu_device (Cu_device 0) in
Expand All @@ -1176,6 +1170,14 @@ module Context = struct
let push_current ctx = check "cu_ctx_push_current" @@ Cuda.cu_ctx_push_current ctx
let set_current ctx = check "cu_ctx_set_current" @@ Cuda.cu_ctx_set_current ctx

let get_primary device =
let open Ctypes in
let ctx = allocate_n cu_context ~count:1 in
check "cu_device_primary_ctx_retain" @@ Cuda.cu_device_primary_ctx_retain ctx device;
let ctx = !@ctx in
Stdlib.Gc.finalise (fun _ -> Device.primary_ctx_release device) ctx;
ctx

let synchronize () =
check "cu_ctx_synchronize" @@ Cuda.cu_ctx_synchronize ();
release_stream no_stream
Expand Down
25 changes: 10 additions & 15 deletions cudajit.mli
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,6 @@ module Device : sig
{{:https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g8bdd1cc7201304b01357b8034f6587cb}
cuDeviceGet}. *)

val primary_ctx_release : t -> unit
(** The context is automatically reset once the last reference to it is released. See
{{:https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PRIMARY__CTX.html#group__CUDA__PRIMARY__CTX_1gf2a8bc16f8df0c88031f6a1ba3d6e8ad}
cuDevicePrimaryCtxRelease}. *)

val primary_ctx_reset : t -> unit
(** Destroys all allocations and resets all state on the primary context. See
{{:https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PRIMARY__CTX.html#group__CUDA__PRIMARY__CTX_1g5d38802e8600340283958a117466ce12}
Expand Down Expand Up @@ -300,7 +295,8 @@ module Context : sig
CUcontext}. *)

val create : flags -> Device.t -> t
(** The context is pushed to the CPU-thread-local stack. See
(** NOTE: In most cases it is recommended to use {!get_primary} instead! The context is pushed to
the CPU-thread-local stack. See
{{:https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g65dc0012348bc84810e2103a40d8e2cf}
cuCtxCreate}
Expand All @@ -314,10 +310,14 @@ module Context : sig
cuCtxGetFlags}. *)

val get_primary : Device.t -> t
(** You should always call {!Device.primary_ctx_release} once done using the retained context. The
context is {i not} pushed to the stack. See
(** The context is {i not} pushed to the stack. See
{{:https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PRIMARY__CTX.html#group__CUDA__PRIMARY__CTX_1g9051f2d5c31501997a6cb0530290a300}
cuDevicePrimaryCtxRetain}. *)
cuDevicePrimaryCtxRetain}.
The context is finalized using
{{:https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PRIMARY__CTX.html#group__CUDA__PRIMARY__CTX_1gf2a8bc16f8df0c88031f6a1ba3d6e8ad}
cuDevicePrimaryCtxRelease}. The underlying CUDA context will be reset once the last
reference to it is released. *)

val get_device : unit -> Device.t
(** See
Expand Down Expand Up @@ -803,12 +803,7 @@ module Delimited_event : sig
cuEventDestroy}) when either it or its owner are synchronized (or if neither happens, when
it is garbage-collected). *)

val record :
?blocking_sync:bool ->
?interprocess:bool ->
?external_:bool ->
Stream.t ->
t
val record : ?blocking_sync:bool -> ?interprocess:bool -> ?external_:bool -> Stream.t -> t
(** Combines {!Event.create} and {!Event.record} to create an event owned by the given stream. *)

val is_released : t -> bool
Expand Down

0 comments on commit 7ce920d

Please sign in to comment.