Skip to content

Commit

Permalink
fix: Update module converters to correctly handle tuple outputs in th…
Browse files Browse the repository at this point in the history
…eir `_forward` methods (#26737)
  • Loading branch information
hmahmood24 authored Oct 10, 2023
1 parent a7d8582 commit 814389e
Showing 1 changed file with 10 additions and 15 deletions.
25 changes: 10 additions & 15 deletions ivy/stateful/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,9 +1078,8 @@ def _forward(self, *a, **kw):
a, kw = ivy.args_to_native(*a, **kw)
params_hk = self._dict_to_hk_flat_map(self.v.cont_to_dict())
ret = self._native_module.apply(params_hk, 0, *a, **kw)
if isinstance(ret, tuple):
return ivy.args_to_native(*ret)
return ivy.to_native(ret)
nested = True if isinstance(ret, tuple) else False
return ivy.to_native(ret, nested=nested)

def _hk_flat_map_to_dict(self, hk_flat_map):
from haiku._src.data_structures import FlatMapping
Expand Down Expand Up @@ -1142,9 +1141,8 @@ def _forward(self, *a, **kw):
a, kw = ivy.args_to_native(*a, **kw)
params_fx = flax.core.freeze(self.v.cont_to_dict())
ret = self._native_module.apply(params_fx, *a, **kw)
if isinstance(ret, tuple):
return ivy.args_to_native(*ret)
return ivy.to_native(ret)
nested = True if isinstance(ret, tuple) else False
return ivy.to_native(ret, nested=nested)


class _KerasIvyModule(Module):
Expand All @@ -1169,9 +1167,8 @@ def _build(self, *args, **kwargs):
def _forward(self, *a, **kw):
a, kw = ivy.args_to_native(*a, **kw)
ret = self._native_module(*a, **kw)
if isinstance(ret, tuple):
return ivy.args_to_native(*ret)
return ivy.to_native(ret)
nested = True if isinstance(ret, tuple) else False
return ivy.to_native(ret, nested=nested)


class _PaddleIvyModule(Module):
Expand Down Expand Up @@ -1201,9 +1198,8 @@ def _build(self, *args, **kwargs):
def _forward(self, *a, **kw):
a, kw = ivy.args_to_native(*a, **kw)
ret = self._native_module(*a, **kw)
if isinstance(ret, tuple):
return ivy.args_to_native(*ret)
return ivy.to_native(ret)
nested = True if isinstance(ret, tuple) else False
return ivy.to_native(ret, nested=nested)


class _TorchIvyModule(Module):
Expand Down Expand Up @@ -1269,6 +1265,5 @@ def _forward(self, *a, **kw):
a, kw = ivy.args_to_native(*a, **kw)
self._update_v(self.v)
ret = self._native_module(*a, **kw)
if isinstance(ret, tuple):
return ivy.args_to_native(*ret)
return ivy.to_native(ret)
nested = True if isinstance(ret, tuple) else False
return ivy.to_native(ret, nested=nested)

0 comments on commit 814389e

Please sign in to comment.