Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement __triton_launcher as pure DLL #3251

Merged
merged 1 commit into from
Jan 24, 2025
Merged

Implement __triton_launcher as pure DLL #3251

merged 1 commit into from
Jan 24, 2025

Conversation

anmyachev
Copy link
Contributor

@anmyachev anmyachev commented Jan 23, 2025

@anmyachev anmyachev marked this pull request as ready for review January 23, 2025 21:22
class TritonLauncher:

def __init__(self, cache_path: str):
self.shared_library = ctypes.PyDLL(cache_path)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Must use PyDLL instead of CDLL (as in d7d55b8) for python api to work correctly, otherwise I get segfaults.


def __call__(self, *args, **kwargs):
# Serialize KernelArguments for SPIR-V Runner
serialize_kernel_args = os.getenv('TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS', None)
if serialize_kernel_args:
serialize_args(args, self.constants, self.signature)
self.launch(*args, **kwargs)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the kwargs were not empty, the launcher would not start without issues, since according to its C signature it supports only args.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, the code unpacked both args and kwargs into a single dictionary and passed that to the launcher - now it looks like we're passing the args dict which seems incorrect as any values in kwargs would be ignored. Could we create a new dictionary that contains both args and kwargs and pass that to maintain the previous behavior?

Copy link
Contributor Author

@anmyachev anmyachev Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, the code unpacked both args and kwargs into a single dictionary and passed that to the launcher - now it looks like we're passing the args dict which seems incorrect as any values in kwargs would be ignored.

I tried to set the dictionary and the launcher immediately crashes. Are you sure that kwargs were ever used before?

>       self.launch(*args, **{"test": 2})
E       TypeError: launch() takes no keyword arguments

Seems only tuple is expected:

if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel,

Perhaps we need to make a cleanup in the triton code...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the calling classes support them: https://github.com/intel/intel-xpu-backend-for-triton/blob/main/python/triton/compiler/compiler.py#L432

In this line, all parameters will be packed into one tuple, since parameters are not passed as ..., name_param=param, ....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexbaden do you mind if I merge it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, interesting - I am sure your Python knowledge is better than mine! :)
Do you think this is something we should propose changing upstream? Or is this specific to our backend?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think this is something we should propose changing upstream? Or is this specific to our backend?

Similar code exists only for AMD backend. I suggested removing unused code in triton-lang/triton#5694.

Just for reference how it works for NVIDIA backend:

self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)

@@ -635,15 +649,14 @@ def __init__(self, src, metadata):
self.constants = {arg_idx(idx): value for idx, value in constants.items()}
self.signature = {idx: value for idx, value in src.signature.items()}
src = make_launcher(self.constants, self.signature)
mod = compile_module_from_src(src, "__triton_launcher")
self.launch = mod.launch
self.mod = compile_module_from_src(src, "__triton_launcher")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To keep a reference to so/dll and not call the destructor prematurely.

@anmyachev anmyachev merged commit 303c0ab into main Jan 24, 2025
9 checks passed
@anmyachev anmyachev deleted the amyachev/3248 branch January 24, 2025 18:53
whitneywhtsang pushed a commit that referenced this pull request Jan 24, 2025
With unloaded DLL libraries, these changes are no longer necessary.
However, two tests that hold a reference to the compiled kernel need to
be adjusted - manually clear the cache (inside `JITFunction` object).

CI:
*
https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/12951798922
(passed)
*
https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/12955820922
(check status)

Blocked on #3251

Extra refs:
* python/cpython#87319

---------

Signed-off-by: Anatoly Myachev <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ImportError: DLL load failed while importing __triton_launcher: The parameter is incorrect
3 participants