Skip to content

Commit

Permalink
Implement __triton_launcher as pure DLL (#3251)
Browse files Browse the repository at this point in the history
Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored Jan 24, 2025
1 parent b018ed6 commit 303c0ab
Showing 1 changed file with 40 additions and 27 deletions.
67 changes: 40 additions & 27 deletions third_party/intel/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,35 @@ def __del__(self):
ctypes.windll.kernel32.FreeLibrary(handle)


class TritonLauncher:

def __init__(self, cache_path: str):
self.shared_library = ctypes.PyDLL(cache_path)
# breakpoint()
self.shared_library.launch.restype = ctypes.py_object
self.shared_library.launch.argtypes = (ctypes.py_object, )

def __getattribute__(self, name):
if name == "launch":
shared_library = super().__getattribute__("shared_library")
return getattr(shared_library, name)

return super().__getattribute__(name)

if os.name != 'nt':

def __del__(self):
handle = self.shared_library._handle
self.shared_library.dlclose.argtypes = (ctypes.c_void_p, )
self.shared_library.dlclose(handle)
else:

def __del__(self):
handle = self.shared_library._handle
ctypes.windll.kernel32.FreeLibrary.argtypes = (ctypes.c_uint64, )
ctypes.windll.kernel32.FreeLibrary(handle)


def compile_module_from_src(src, name):
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
cache = get_cache_manager(key)
Expand All @@ -192,6 +221,8 @@ def compile_module_from_src(src, name):

if name == 'arch_utils':
return ArchParser(cache_path)
elif name == '__triton_launcher':
return TritonLauncher(cache_path)

import importlib.util
spec = importlib.util.spec_from_file_location(name, cache_path)
Expand Down Expand Up @@ -339,6 +370,12 @@ def format_of(ty):
#include <sycl/sycl.hpp>
{ "#include <ATen/record_function.h>" if COMPILATION_HELPER.inject_pytorch_dep else "" }
#if defined(_WIN32)
#define EXPORT_FUNC __declspec(dllexport)
#else
#define EXPORT_FUNC __attribute__((visibility("default")))
#endif
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include <stdio.h>
Expand Down Expand Up @@ -466,8 +503,7 @@ def format_of(ty):
}}
// end sycl
static PyObject* launch(PyObject* self, PyObject* args) {{
extern "C" EXPORT_FUNC PyObject* launch(PyObject* args) {{
int gridX, gridY, gridZ;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
Expand Down Expand Up @@ -541,28 +577,6 @@ def format_of(ty):
Py_RETURN_NONE;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
"""
return src

Expand Down Expand Up @@ -635,15 +649,14 @@ def __init__(self, src, metadata):
self.constants = {arg_idx(idx): value for idx, value in constants.items()}
self.signature = {idx: value for idx, value in src.signature.items()}
src = make_launcher(self.constants, self.signature)
mod = compile_module_from_src(src, "__triton_launcher")
self.launch = mod.launch
self.mod = compile_module_from_src(src, "__triton_launcher")

def __call__(self, *args, **kwargs):
# Serialize KernelArguments for SPIR-V Runner
serialize_kernel_args = os.getenv('TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS', None)
if serialize_kernel_args:
serialize_args(args, self.constants, self.signature)
self.launch(*args, **kwargs)
self.mod.launch(args)


class XPUDriver(DriverBase):
Expand Down

0 comments on commit 303c0ab

Please sign in to comment.