Skip to content

Commit

Permalink
Merge branch 'py38-windows-dll-path' into 'master'
Browse files Browse the repository at this point in the history
Try to discover CUDA binary path, to add to Py3.8's DLL path (Closes inducergh-213)

See merge request inducer/pycuda!17
  • Loading branch information
inducer committed Jul 11, 2019
2 parents 25532a0 + 27c281b commit edc50bb
Showing 1 changed file with 59 additions and 5 deletions.
64 changes: 59 additions & 5 deletions pycuda/driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,63 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import absolute_import, print_function

import os
import sys

import six

import numpy as np


# {{{ add cuda lib dir to Python DLL path

def _search_on_path(filenames):
"""Find file on system path."""
# http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/52224

from os.path import exists, abspath, join
from os import pathsep, environ

search_path = environ["PATH"]

paths = search_path.split(pathsep)
for path in paths:
for filename in filenames:
if exists(join(path, filename)):
return abspath(join(path, filename))


def _add_cuda_libdir_to_dll_path():
from os.path import join, dirname

cuda_path = os.environ.get("CUDA_PATH")

if cuda_path is not None:
os.add_dll_directory(join(cuda_path, 'bin'))
return

nvcc_path = _search_on_path(["nvcc.exe"])
if nvcc_path is not None:
os.add_dll_directory(dirname(nvcc_path))

from warnings import warn
warn("Unable to discover CUDA installation directory "
"while attempting to add it to Python's DLL path. "
"Either set the 'CUDA_PATH' environment variable "
"or ensure that 'nvcc.exe' is on the path.")


try:
os.add_dll_directory
except AttributeError:
# likely not on Py3.8 and Windows
# https://github.com/inducer/pycuda/issues/213
pass
else:
_add_cuda_libdir_to_dll_path()

# }}}


try:
from pycuda._driver import * # noqa
except ImportError as e:
Expand All @@ -11,9 +68,6 @@
"does not match the version of your CUDA driver.")
raise

import numpy as np
import sys


if sys.version_info >= (3,):
_memoryview = memoryview
Expand Down

0 comments on commit edc50bb

Please sign in to comment.