From 408033d04c12669e23df1129725e2d5776fe7fc3 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 21 Nov 2024 10:35:21 -0800 Subject: [PATCH] fix remat checkpoint of input --- axlearn/common/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index 152c561be..5a8219738 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -3159,7 +3159,7 @@ def _forward_for_mode( Raises: ValueError: If `mode` is unsupported. """ - self._remat_name(data, "input") + data = self._remat_name(data, "input") self.vlog(3, "transformer.input=%s", data.sum()) self_attention_return_aux = set() cross_attention_return_aux = set()