Fix Pytorch inductor tests workflow since it doesn't work due to Triton interface change in #3043 (#3080)

Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev authored Jan 4, 2025
1 parent 82af4c7 commit 8f1f00e
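
The patch below adapts scripts/pytorch.patch to the newer Triton interface, where constexpr kernel arguments stay in the compilation signature (marked with the string "constexpr") instead of being filtered out and passed only through a "constants" mapping. A minimal sketch of the two layouts, with invented argument names and values; this is an illustration drawn from the hunks below, not Triton's documented API:

# Older layout, as the removed lines suggest: constexpr args kept out of the signature.
old_meta = {
    "signature": {"in_ptr0": "*fp32", "out_ptr0": "*fp32", "xnumel": "i32"},
    "constants": {"XBLOCK": 1024},
}

# Newer layout, as the added lines suggest: constexpr args stay in the signature
# as the literal string "constexpr", with values tracked in a "constexprs" mapping.
new_meta = {
    "signature": {"in_ptr0": "*fp32", "out_ptr0": "*fp32", "xnumel": "i32", "XBLOCK": "constexpr"},
    "constexprs": {"XBLOCK": 1024},
}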
Showing 2 changed files with 84 additions and 179 deletions.
1 change: 0 additions & 1 deletion scripts/patch-pytorch.sh
@@ -17,5 +17,4 @@ echo "Applying PyTorch patches in $REPO_ROOT"
cd "$REPO_ROOT"

curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -
-# REVERT ME: it's just a trigger for pytorch rebuild
git apply "${SCRIPT_DIR}/pytorch.patch"
262 changes: 84 additions & 178 deletions scripts/pytorch.patch
@@ -1,5 +1,5 @@
diff --git a/test/inductor/test_codegen_triton.py b/test/inductor/test_codegen_triton.py
-index 84264bf1b0..eb3fac0f39 100644
+index 84264bf1b01..aa9a624d5ac 100644
--- a/test/inductor/test_codegen_triton.py
+++ b/test/inductor/test_codegen_triton.py
@@ -48,7 +48,7 @@ class TestCodegenTriton(InductorTestCase):
@@ -16,12 +16,37 @@ index 84264bf1b0..eb3fac0f39 100644

self.assertEqual(
- (0, 2, 4, 5, 6),
+ [(0, 2, 4, 5, 6)],
+ [(0,), (2,), (4,), (5,), (6,)],
_check_divisibility(
triton_utils.config_of(
[
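
Per the updated expectation above, the divisibility hint checked by the test changes from one flat tuple of argument indices to a list of one-element tuples, one per argument. A small sketch of the reshaping (index values taken from the test):

old_expected = (0, 2, 4, 5, 6)
new_expected = [(i,) for i in old_expected]
assert new_expected == [(0,), (2,), (4,), (5,), (6,)]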
diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py
index 4d7a85029e3..f3d45ea5520 100644
--- a/test/inductor/test_triton_kernels.py
+++ b/test/inductor/test_triton_kernels.py
@@ -1268,9 +1268,9 @@ def forward(self, x_1, output_1):
if dynamic:
# when half_n_elements passed to the Triton kernel is
# dynamic, equal_to_1 specializaiton can't be enforced
- self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
+ self.assertTrue(_triton_get_ast_equal_to_str([]) in sources[0])
else:
- self.assertTrue(_triton_get_ast_equal_to_str((3,)) in sources[0])
+ self.assertTrue(_triton_get_ast_equal_to_str([(3,)]) in sources[0])
self.assertEqual(compiled_out, eager_out)

@requires_gpu
@@ -1299,7 +1299,7 @@ def forward(self, x_1, output_1):

# float 1.0 (both literal or symbolic)
# should not be added to equal_to_1
- self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
+ self.assertTrue(_triton_get_ast_equal_to_str([]) in sources[0])
self.assertEqual(compiled_out, eager_out)

@requires_gpu
diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
-index ace56135fe..568cbde0a1 100644
+index ace56135fe1..7e925dd6e45 100644
--- a/torch/_higher_order_ops/triton_kernel_wrap.py
+++ b/torch/_higher_order_ops/triton_kernel_wrap.py
@@ -238,7 +238,7 @@ def generate_ttir(
@@ -33,43 +58,41 @@ index ace56135fe..568cbde0a1 100644
except ImportError:
return kernel._get_config(*args)

@@ -251,7 +251,6 @@ def generate_ttir(
@@ -247,7 +247,8 @@ def generate_ttir(
name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor)
}

- # Build kernel signature -- doesn't include constexpr arguments.
+ # Build kernel signature; it should also include `constexpr` arguments but `kernel._key_of`
+ # doesn't work correctly with them. They will be added in `fixup_signature` function later.
signature = {
name: kernel._type_of(kernel._key_of(arg))
for i, (name, arg) in enumerate(ordered_args.items())
- if i not in kernel.constexprs
}

@@ -257,7 +258,18 @@ def generate_ttir(
triton._C.libtriton.ir.load_dialects(context)
backend.load_dialects(context)

- src = ASTSource(kernel, signature, constants, specialization)
+ def fixup_signature(arg_names, signature, constants):
+ new_signature = {arg_name: None for arg_name in arg_names}
+ for arg_name in arg_names:
+ if arg_name in constants and arg_name not in signature:
+ # If it's not in the signature already, it's a constexpr
+ # argument that we need to add in
+ new_signature[arg_name] = "constexpr"
+ else:
+ new_signature[arg_name] = signature[arg_name]
+ return new_signature
+
+ src = ASTSource(kernel, fixup_signature(kernel.arg_names, signature, constants), constants, specialization)

# Triton changes ASTSource.make_ir to take 3/4 arguments. Handle
# backward compatibility here.
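
The fixup_signature helper introduced above re-inserts constexpr argument names that were skipped when the signature was built from kernel._key_of. A self-contained sketch with invented argument names (the helper body mirrors the patch; the example data is an assumption):

def fixup_signature(arg_names, signature, constants):
    new_signature = {arg_name: None for arg_name in arg_names}
    for arg_name in arg_names:
        if arg_name in constants and arg_name not in signature:
            # Not in the signature yet, so it is a constexpr argument to add back in.
            new_signature[arg_name] = "constexpr"
        else:
            new_signature[arg_name] = signature[arg_name]
    return new_signature

arg_names = ["x_ptr", "y_ptr", "n_elements", "BLOCK_SIZE"]
signature = {"x_ptr": "*fp32", "y_ptr": "*fp32", "n_elements": "i32"}
constants = {"BLOCK_SIZE": 1024}
print(fixup_signature(arg_names, signature, constants))
# {'x_ptr': '*fp32', 'y_ptr': '*fp32', 'n_elements': 'i32', 'BLOCK_SIZE': 'constexpr'}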
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
-index 00031a56b8..b941e2aaa6 100644
+index 00031a56b8d..59086d41b40 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -2980,6 +2980,7 @@ class TritonKernel(SIMDKernel):
code.splice(self.imports_for_benchmark_kernel())

argdefs, _, signature, _ = self.args.python_argdefs()
+ # breakpoint()
# maps actual expression to SizeArg if it is in sizevars replacements
for i, arg in enumerate(signature):
if isinstance(arg, SizeArg):
@@ -3030,7 +3031,7 @@ class TritonKernel(SIMDKernel):
triton_meta = {
"signature": triton_meta_signature,
"device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
- "constants": {},
+ "constexprs": {},
}

# Skip memory optimization for forward of the training loop where we expect
@@ -3065,20 +3066,12 @@ class TritonKernel(SIMDKernel):
argdefs.append(f"{tree.prefix}numel")
# constexpr version causes issues, see
# https://github.com/pytorch/torchdynamo/pull/1362
- # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
+ # triton_meta["constexprs"][len(argdefs)] = V.graph.sizevars.size_hint(
# tree.numel
# )
@@ -3071,14 +3071,6 @@ class TritonKernel(SIMDKernel):
# argdefs.append(f"{tree.prefix}numel: tl.constexpr")
triton_meta["configs"] = [config_of(signature)]

@@ -84,175 +107,58 @@ index 00031a56b8..b941e2aaa6 100644
self.triton_meta = triton_meta

for tree in self.range_trees:
@@ -3087,9 +3080,14 @@ class TritonKernel(SIMDKernel):
continue
if tree.tensor_dim is None:
continue
- argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")
+ const_name = f"{tree.prefix.upper()}BLOCK"
+ triton_meta['signature'][const_name] = 'constexpr'
+ triton_meta['constexprs'][const_name] = tree.numel
+ argdefs.append(f"{const_name} : tl.constexpr")

if self.cooperative_reduction:
+ triton_meta['signature']['RSPLIT'] = 'constexpr'
+ triton_meta['constexprs']['RSPLIT'] = tree.numel
argdefs.append("RSPLIT : tl.constexpr")

self.codegen_body()
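
For each range tree, the hunk above now registers the block-size constant in both the signature and the new constexprs mapping before appending it to argdefs. A rough illustration with invented values (XBLOCK and the numel are stand-ins):

triton_meta = {"signature": {"in_ptr0": "*fp32", "xnumel": "i32"}, "constexprs": {}}
argdefs = ["in_ptr0", "xnumel"]

const_name = "XBLOCK"            # f"{tree.prefix.upper()}BLOCK" for an "x" tree
triton_meta["signature"][const_name] = "constexpr"
triton_meta["constexprs"][const_name] = 1024   # stand-in for tree.numel
argdefs.append(f"{const_name} : tl.constexpr")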
diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
index 8b8c29bbb1..3e5abaa824 100644
--- a/torch/_inductor/codegen/triton_utils.py
+++ b/torch/_inductor/codegen/triton_utils.py
@@ -157,13 +157,13 @@ def config_of(
raise NotImplementedError(f"unhandled {type(x)}: {x}")

if config.triton.divisible_by_16:
- divisible_by_16 = tuple(
+ divisible_by_16 = [tuple(
i
for i, arg in zip(indices, args)
if is_aligned(arg, alignment=16, include_tensor=True)
- )
+ )]
else:
- divisible_by_16 = ()
+ divisible_by_16 = []

equal_to_1 = tuple(
i
@@ -172,5 +172,7 @@ def config_of(
and isinstance(arg.expr, (int, sympy.Integer))
and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type]
)
+ if equal_to_1 != tuple():
+ equal_to_1 = [equal_to_1]

return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
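
The net effect of the config_of changes above is one extra level of nesting in the hints handed to AttrsDescriptorWrapper. A minimal before/after sketch with invented argument indices:

# before the patch
divisible_by_16 = (0, 1, 3)
equal_to_1 = (4,)

# after the patch
divisible_by_16 = [(0, 1, 3)]
if equal_to_1 != tuple():
    equal_to_1 = [equal_to_1]     # [(4,)]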
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
-index 2ab2b32635..5f08c3c0b7 100644
+index 2ab2b326354..42d76b8bf94 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -1535,16 +1535,21 @@ class PythonWrapperCodegen(CodeGen):

signature: List[KernelArgType] = []
constants: Dict[str, Any] = {}
+ constexprs = {}
non_constant_indices = []
equal_to_1_args: List[str] = []
for idx, key in enumerate(kernel.arg_names):
if key not in kwargs:
+ if idx in kernel.constexprs:
+ constexprs[key] = 'constexpr'
continue
arg = kwargs[key]
if idx in kernel.constexprs:
constants[key] = arg
+ constexprs[key] = 'constexpr'
elif kwargs[key] is None:
constants[key] = None
+ constexprs[key] = 'constexpr'
else:
non_constant_indices.append(idx)
if isinstance(arg, ir.TMADescriptor):
@@ -1596,9 +1601,8 @@ class PythonWrapperCodegen(CodeGen):
# causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input.
# https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
@@ -1598,7 +1598,6 @@ class PythonWrapperCodegen(CodeGen):
# https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
- "constants": {
+ "constexprs": {
"constants": {
**constants,
- **dict.fromkeys(equal_to_1_args, 1),
},
"configs": [
config_of(
@@ -1607,6 +1611,8 @@ class PythonWrapperCodegen(CodeGen):
)
],
}
+ for constexpr_name in constexprs.keys():
+ triton_meta['signature'][constexpr_name] = 'constexpr'

if restore_value_args:
triton_meta["restore_value"] = tuple(restore_value_args)
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
-index 276c01f3f4..4e6e1ab9ce 100644
+index 276c01f3f42..5c633b7963b 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -53,6 +53,7 @@ if _is_triton_available():
@@ -48,8 +48,8 @@ if _is_triton_available():
):
# Prepare the arguments for AttrsDescriptor
kwargs = {
- "tt.divisibility": divisible_by_16,
- "tt.equal_to": equal_to_1,
+ "tt.divisibility": tuple([(i,) for i in divisible_by_16]),
+ "tt.equal_to": tuple([(i,) for i in equal_to_1]),
}

# Instantiate AttrsDescriptor with the prepared arguments
+ # breakpoint()
res = AttrsDescriptor.from_dict(
{"arg_properties": kwargs, "cls": AttrsDescriptor.__name__}
)
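
The hints.py change above repacks each hint index into its own one-element tuple before building the AttrsDescriptor. A sketch assuming divisible_by_16 and equal_to_1 arrive here as flat sequences of argument indices (the values are invented):

divisible_by_16 = (0, 1, 3)
equal_to_1 = (4,)
kwargs = {
    "tt.divisibility": tuple([(i,) for i in divisible_by_16]),  # ((0,), (1,), (3,))
    "tt.equal_to": tuple([(i,) for i in equal_to_1]),           # ((4,),)
}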
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
-index af8530e94d..a1935831e2 100644
+index af8530e94d0..1ec44de9806 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -407,6 +407,7 @@ class CachingAutotuner(KernelInterface):

def _precompile_config(self, cfg: Config, warm_cache_only: bool):
"""Ahead of time compile a given autotuner config."""
+ # print(f"self.triton_meta: {self.triton_meta}")
compile_meta = copy.deepcopy(self.triton_meta)
for k, v in cfg.kwargs.items():
if self.device_props.type == "hip":
@@ -419,7 +420,7 @@ class CachingAutotuner(KernelInterface):
if k == "kpack":
compile_meta["kpack"] = v
continue
- compile_meta["constants"][k] = v
+ compile_meta["constexprs"][k] = v
compile_meta["num_warps"] = cfg.num_warps
compile_meta["num_stages"] = cfg.num_stages
compile_meta["debug"] = self.inductor_meta.get(
@@ -435,12 +436,13 @@ class CachingAutotuner(KernelInterface):
@@ -435,11 +435,22 @@ class CachingAutotuner(KernelInterface):
else:
triton_helpers.set_driver_to_gpu()

+ # print(compile_meta)
+ def fixup_signature(arg_names, signature, constants):
+ new_signature = {arg_name: None for arg_name in arg_names}
+ for arg_name in arg_names:
+ if arg_name in constants and arg_name not in signature:
+ # If it's not in the signature already, it's a constexpr
+ # argument that we need to add in
+ new_signature[arg_name] = "constexpr"
+ else:
+ new_signature[arg_name] = signature[arg_name]
+ return new_signature
+
if ASTSource:
compile_args = (
ASTSource(
self.fn,
compile_meta["signature"],
- compile_meta["constants"],
+ compile_meta["constexprs"],
- compile_meta["signature"],
+ fixup_signature(self.fn.arg_names, compile_meta["signature"], compile_meta["constants"]),
compile_meta["constants"],
compile_meta["configs"][0],
),
)
@@ -527,7 +529,7 @@ class CachingAutotuner(KernelInterface):
We also don't want to modify self.fn.

We know that we removed something from the signature if:
- 1. It's in compile_meta["constants"]
+ 1. It's in compile_meta["constexprs"]
2. It isn't a constant we already know about
Note: The value of interest has already been added to compile_meta['constants'],
so we use self.fn.constexprs instead.
@@ -538,7 +540,7 @@ class CachingAutotuner(KernelInterface):
}
none_args = {
k
- for k, v in compile_meta["constants"].items()
+ for k, v in compile_meta["constexprs"].items()
if v is None and k not in known_constants
}
none_args = none_args.difference(set(compile_meta["signature"].keys()))
@@ -548,12 +550,14 @@ class CachingAutotuner(KernelInterface):
for i, arg in enumerate(self.fn.arg_names)
if i not in self.fn.constexprs and arg not in none_args
]
+ # print(f"call_args: {call_args}")

def_args = [
name
for name in self.fn.arg_names
if name not in cfg.kwargs and name not in none_args
]
+ # print(f"def_args: {def_args}\n")
binary_shared = (
binary.shared if hasattr(binary, "shared") else binary.metadata.shared
)
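
In _precompile_config the same signature fix-up is applied before constructing ASTSource, so autotuned kernels also expose their constexpr arguments. A compact, self-contained sketch with invented metadata (the helper is a shortened equivalent of the one added in the hunk above):

def fixup_signature(arg_names, signature, constants):
    return {
        name: "constexpr" if name in constants and name not in signature else signature[name]
        for name in arg_names
    }

compile_meta = {
    "signature": {"in_ptr0": "*fp32", "out_ptr0": "*fp32", "xnumel": "i32"},
    "constants": {"XBLOCK": 1024},
}
arg_names = ["in_ptr0", "out_ptr0", "xnumel", "XBLOCK"]
full_signature = fixup_signature(arg_names, compile_meta["signature"], compile_meta["constants"])
# {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}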
