From 0ad175ab660693350ad350715ccf7c03021d7ea4 Mon Sep 17 00:00:00 2001 From: Lukasz Stafiniak Date: Mon, 16 Dec 2024 10:46:40 +0100 Subject: [PATCH] `Deviceptr.equal` and `Deviceptr.hash` --- CHANGES.md | 4 ++++ cudajit.ml | 8 ++++++-- cudajit.mli | 7 +++++++ test/saxpy.ml | 31 ++++++++++++++++++------------- 4 files changed, 35 insertions(+), 15 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 0be21af..11d070c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,6 +4,10 @@ - Documentation for `Deviceptr.mem_free`: mention it's safe to call multiple times on the same pointer. +### Added + +- `Deviceptr.equal` and `Deviceptr.hash`. + ## [0.6.1] 2024-12-02 ### Added diff --git a/cudajit.ml b/cudajit.ml index 936fa19..bd8e03b 100644 --- a/cudajit.ml +++ b/cudajit.ml @@ -1387,6 +1387,11 @@ let get_size_in_bytes ?kind ?length ?size_in_bytes provenance = module Deviceptr = struct type t = deviceptr [@@deriving sexp_of] + let equal (Deviceptr { ptr = ptr1; freed = _ }) (Deviceptr { ptr = ptr2; freed = _ }) = + Unsigned.UInt64.equal ptr1 ptr2 + + let hash (Deviceptr { ptr; freed = _ }) = Unsigned.UInt64.to_int ptr + let string_of (Deviceptr { ptr; freed }) = let addr = string_of_memptr ptr in if Atomic.get freed then addr ^ "/FREED" else addr @@ -1693,8 +1698,7 @@ module Stream = struct | Double of float [@@deriving sexp_of] - let no_stream = - { args_lifetimes = []; owned_events = []; stream = Ctypes.(coerce (ptr void) cu_stream null) } + let no_stream = no_stream let total_unreleased_unfinished_delimited_events stream = List.fold_left diff --git a/cudajit.mli b/cudajit.mli index c9e9414..8dec86a 100644 --- a/cudajit.mli +++ b/cudajit.mli @@ -413,6 +413,13 @@ module Deviceptr : sig {{:https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g183f7b0d8ad008ea2a5fd552537ace4e} CUdeviceptr}. *) + val equal : t -> t -> bool + (** Compares the pointer values for equality. *) + + val hash : t -> int + (** Converts the pointer to an OCaml int using {!Unsigned.UInt64.to_int} (truncating bits as + needed). *) + val string_of : t -> string (** Hexadecimal representation of the pointer. *) diff --git a/test/saxpy.ml b/test/saxpy.ml index b010aa9..8794abf 100644 --- a/test/saxpy.ml +++ b/test/saxpy.ml @@ -41,6 +41,24 @@ let%expect_test "SAXPY" = let dY = Cu.Deviceptr.alloc_and_memcpy hY in let hOut = Host.create Bigarray.Float32 Bigarray.C_layout [| size |] in let dOut = Cu.Deviceptr.alloc_and_memcpy hOut in + let eq = Cu.Deviceptr.equal in + Printf.printf "dX = dX %b; dX = dY %b; dY = dOut %b.\n" (eq dX dX) (eq dX dY) (eq dY dOut); + [%expect + {| + cu_init + cu_device_get_count + cu_device_get + cu_ctx_create + cu_ctx_get_current + cu_module_load_data_ex + cu_module_get_function + cu_mem_alloc + cu_memcpy_H_to_D + cu_mem_alloc + cu_memcpy_H_to_D + cu_mem_alloc + cu_memcpy_H_to_D + dX = dX true; dX = dY false; dY = dOut false.|}]; Cu.Stream.launch_kernel kernel ~grid_dim_x:num_blocks ~block_dim_x:num_threads ~shared_mem_bytes:0 Cu.Stream.no_stream [ @@ -64,19 +82,6 @@ let%expect_test "SAXPY" = Gc.full_major (); [%expect {| - cu_init - cu_device_get_count - cu_device_get - cu_ctx_create - cu_ctx_get_current - cu_module_load_data_ex - cu_module_get_function - cu_mem_alloc - cu_memcpy_H_to_D - cu_mem_alloc - cu_memcpy_H_to_D - cu_mem_alloc - cu_memcpy_H_to_D cu_launch_kernel cu_ctx_synchronize cu_memcpy_D_to_H