From 814389e953af33d36ec1b06b6e0d6a854f7540d8 Mon Sep 17 00:00:00 2001 From: Haris Mahmood <70361308+hmahmood24@users.noreply.github.com> Date: Tue, 10 Oct 2023 00:25:56 -0700 Subject: [PATCH] fix: Update module converters to correctly handle tuple outputs in their `_forward` methods (#26737) --- ivy/stateful/module.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/ivy/stateful/module.py b/ivy/stateful/module.py index 303ade50f0306..68f1c425b7658 100644 --- a/ivy/stateful/module.py +++ b/ivy/stateful/module.py @@ -1078,9 +1078,8 @@ def _forward(self, *a, **kw): a, kw = ivy.args_to_native(*a, **kw) params_hk = self._dict_to_hk_flat_map(self.v.cont_to_dict()) ret = self._native_module.apply(params_hk, 0, *a, **kw) - if isinstance(ret, tuple): - return ivy.args_to_native(*ret) - return ivy.to_native(ret) + nested = True if isinstance(ret, tuple) else False + return ivy.to_native(ret, nested=nested) def _hk_flat_map_to_dict(self, hk_flat_map): from haiku._src.data_structures import FlatMapping @@ -1142,9 +1141,8 @@ def _forward(self, *a, **kw): a, kw = ivy.args_to_native(*a, **kw) params_fx = flax.core.freeze(self.v.cont_to_dict()) ret = self._native_module.apply(params_fx, *a, **kw) - if isinstance(ret, tuple): - return ivy.args_to_native(*ret) - return ivy.to_native(ret) + nested = True if isinstance(ret, tuple) else False + return ivy.to_native(ret, nested=nested) class _KerasIvyModule(Module): @@ -1169,9 +1167,8 @@ def _build(self, *args, **kwargs): def _forward(self, *a, **kw): a, kw = ivy.args_to_native(*a, **kw) ret = self._native_module(*a, **kw) - if isinstance(ret, tuple): - return ivy.args_to_native(*ret) - return ivy.to_native(ret) + nested = True if isinstance(ret, tuple) else False + return ivy.to_native(ret, nested=nested) class _PaddleIvyModule(Module): @@ -1201,9 +1198,8 @@ def _build(self, *args, **kwargs): def _forward(self, *a, **kw): a, kw = ivy.args_to_native(*a, **kw) ret = self._native_module(*a, **kw) - if isinstance(ret, tuple): - return ivy.args_to_native(*ret) - return ivy.to_native(ret) + nested = True if isinstance(ret, tuple) else False + return ivy.to_native(ret, nested=nested) class _TorchIvyModule(Module): @@ -1269,6 +1265,5 @@ def _forward(self, *a, **kw): a, kw = ivy.args_to_native(*a, **kw) self._update_v(self.v) ret = self._native_module(*a, **kw) - if isinstance(ret, tuple): - return ivy.args_to_native(*ret) - return ivy.to_native(ret) + nested = True if isinstance(ret, tuple) else False + return ivy.to_native(ret, nested=nested)