From 9d7db9dd0769bf2573bc8012ef4e795f5ae7bd06 Mon Sep 17 00:00:00 2001
From: jakirkham <jakirkham@gmail.com>
Date: Mon, 13 Jan 2025 13:53:48 -0800
Subject: [PATCH] Switch to `pynvml_utils.smi` for PyNVML 12

---
 .../cugraph/standalone/bulk_sampling/bench_cugraph_training.py  | 2 +-
 python/utils/gpu_metric_poller.py                               | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py
index 2604642b748..7b3e7a6e1d0 100644
--- a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py
+++ b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py
@@ -36,7 +36,7 @@
 def init_pytorch_worker(rank: int, use_rmm_torch_allocator: bool = False) -> None:
     import cupy
     import rmm
-    from pynvml.smi import nvidia_smi
+    from pynvml_utils.smi import nvidia_smi
 
     smi = nvidia_smi.getInstance()
     pool_size = 16e9  # FIXME calculate this
diff --git a/python/utils/gpu_metric_poller.py b/python/utils/gpu_metric_poller.py
index 854552fb34f..8b02163fafc 100755
--- a/python/utils/gpu_metric_poller.py
+++ b/python/utils/gpu_metric_poller.py
@@ -31,7 +31,7 @@
 import os
 import sys
 import threading
-from pynvml import smi
+from pynvml_utils import smi
 
 
 class GPUMetricPoller(threading.Thread):