Skip to content

Commit

Permalink
Update GumbelSoftmaxTransformation tau value and add to_dict method
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Apr 3, 2024
1 parent 7984e05 commit f8b5618
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
5 changes: 4 additions & 1 deletion nbs/01_data.utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@
"class GumbelSoftmaxTransformation(_OneHotTransformation):\n",
" \"\"\"Apply Gumbel softmax tricks for categorical transformation.\"\"\"\n",
"\n",
" def __init__(self, tau: float = 1.):\n",
" def __init__(self, tau: float = .1):\n",
" super().__init__(\"gumbel\")\n",
" self.tau = tau\n",
" \n",
Expand All @@ -560,6 +560,9 @@
" \"\"\"Apply constraints to the counterfactuals. If `rng_key` is None, no randomness is used.\"\"\"\n",
" return super().apply_constraints(xs, cfs, hard, rng_key, **kwargs)\n",
" \n",
" def to_dict(self) -> dict:\n",
" return super().to_dict() | {\"tau\": self.tau}\n",
" \n",
"def OneHotTransformation():\n",
" warnings.warn(\"OneHotTransformation is deprecated since v0.2.5. \"\n",
" \"Use `SoftmaxTransformation` (same functionality) \"\n",
Expand Down
2 changes: 2 additions & 0 deletions relax/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@
'relax/data_utils.py'),
'relax.data_utils.GumbelSoftmaxTransformation.soft_constraints': ( 'data.utils.html#gumbelsoftmaxtransformation.soft_constraints',
'relax/data_utils.py'),
'relax.data_utils.GumbelSoftmaxTransformation.to_dict': ( 'data.utils.html#gumbelsoftmaxtransformation.to_dict',
'relax/data_utils.py'),
'relax.data_utils.IdentityTransformation': ( 'data.utils.html#identitytransformation',
'relax/data_utils.py'),
'relax.data_utils.IdentityTransformation.__init__': ( 'data.utils.html#identitytransformation.__init__',
Expand Down
5 changes: 4 additions & 1 deletion relax/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):
class GumbelSoftmaxTransformation(_OneHotTransformation):
"""Apply Gumbel softmax tricks for categorical transformation."""

def __init__(self, tau: float = 1.):
def __init__(self, tau: float = .1):
super().__init__("gumbel")
self.tau = tau

Expand All @@ -267,6 +267,9 @@ def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs)
"""Apply constraints to the counterfactuals. If `rng_key` is None, no randomness is used."""
return super().apply_constraints(xs, cfs, hard, rng_key, **kwargs)

def to_dict(self) -> dict:
return super().to_dict() | {"tau": self.tau}

def OneHotTransformation():
warnings.warn("OneHotTransformation is deprecated since v0.2.5. "
"Use `SoftmaxTransformation` (same functionality) "
Expand Down

0 comments on commit f8b5618

Please sign in to comment.