Skip to content

Commit

Permalink
Merge pull request #42 from BirkhoffG/patch
Browse files Browse the repository at this point in the history
Fix `GumbelSoftmaxTransformation.name`; add more tests for `set_transformation`
  • Loading branch information
BirkhoffG authored Feb 22, 2024
2 parents 8f3ec4c + f89c059 commit 7984e05
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 64 deletions.
105 changes: 44 additions & 61 deletions nbs/01_data.utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,15 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"source": [
"#| export\n",
"from __future__ import annotations\n",
Expand Down Expand Up @@ -493,8 +501,8 @@
"source": [
"#| export\n",
"class _OneHotTransformation(Transformation):\n",
" def __init__(self):\n",
" super().__init__(\"ohe\", OneHotEncoder())\n",
" def __init__(self, name: str = None):\n",
" super().__init__(name, OneHotEncoder())\n",
"\n",
" @property\n",
" def num_categories(self) -> int:\n",
Expand Down Expand Up @@ -528,6 +536,9 @@
"source": [
"#| export\n",
"class SoftmaxTransformation(_OneHotTransformation):\n",
" def __init__(self): \n",
" super().__init__(\"ohe\")\n",
"\n",
" def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n",
" x, rng_key, kwargs = operand\n",
" return jax.nn.softmax(x, axis=-1)\n",
Expand All @@ -536,7 +547,7 @@
" \"\"\"Apply Gumbel softmax tricks for categorical transformation.\"\"\"\n",
"\n",
" def __init__(self, tau: float = 1.):\n",
" super().__init__()\n",
" super().__init__(\"gumbel\")\n",
" self.tau = tau\n",
" \n",
" def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):\n",
Expand Down Expand Up @@ -1155,71 +1166,43 @@
"metadata": {},
"outputs": [],
"source": [
"# Test set_transformations\n",
"feats_list_2 = deepcopy(feats_list)\n",
"feats_list_2.set_transformations({\n",
" feat: 'ordinal' for feat in cat_feats\n",
"})\n",
"assert feats_list_2.transformed_data.shape == (32561, 8)\n",
"def test_set_transformations(transformation, correct_shape):\n",
" T = transformation\n",
" feats_list_2 = deepcopy(feats_list)\n",
" feats_list_2.set_transformations({\n",
" feat: T for feat in cat_feats\n",
" })\n",
" assert feats_list_2.transformed_data.shape == correct_shape\n",
" name = T.name if isinstance(T, Transformation) else T\n",
"\n",
"for feat in feats_list_2:\n",
" if feat.name in cat_feats: \n",
" assert feat.transformation.name == 'ordinal'\n",
" assert feat.is_categorical\n",
" else:\n",
" assert feat.transformation.name == 'minmax' \n",
" assert feat.is_categorical is False\n",
" assert feat.is_immutable is False\n",
"del feats_list_2"
" for feat in feats_list_2:\n",
" if feat.name in cat_feats: \n",
" assert feat.transformation.name == name\n",
" assert feat.is_categorical\n",
" else:\n",
" assert feat.transformation.name == 'minmax' \n",
" assert feat.is_categorical is False\n",
" assert feat.is_immutable is False\n",
"\n",
" x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, correct_shape[-1]))\n",
" _ = feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=False)\n",
" _ = feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array([[0.22846544, 0.1524936 , 0. , ..., 1. , 1. ,\n",
" 0. ],\n",
" [0.17293715, 0.5796174 , 0. , ..., 0. , 0. ,\n",
" 1. ],\n",
" [0.17434704, 0.8137592 , 0. , ..., 1. , 1. ,\n",
" 0. ],\n",
" ...,\n",
" [0.68356454, 0.65396845, 0. , ..., 0. , 1. ,\n",
" 0. ],\n",
" [0.73027587, 0.4722154 , 1. , ..., 1. , 0. ,\n",
" 1. ],\n",
" [0.8495003 , 0.04826355, 1. , ..., 1. , 0. ,\n",
" 1. ]], dtype=float32)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Test set_transformations\n",
"feats_list_2 = FeaturesList.from_dict(feats_list.to_dict())\n",
"feats_list_2.set_transformations({\n",
" feat: OneHotTransformation() for feat in cat_feats\n",
"})\n",
"assert feats_list_2.transformed_data.shape == (32561, 29)\n",
"\n",
"for feat in feats_list_2:\n",
" if feat.name in cat_feats: \n",
" assert feat.transformation.name == 'ohe'\n",
" assert feat.is_categorical\n",
" else:\n",
" assert feat.transformation.name == 'minmax' \n",
" assert feat.is_categorical is False\n",
" assert feat.is_immutable is False\n",
"\n",
"x = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 29))\n",
"feats_list_2.apply_constraints(feats_list_2.transformed_data[:100], x, hard=True)"
"test_set_transformations('ordinal', (32561, 8))\n",
"test_set_transformations('ohe', (32561, 29))\n",
"test_set_transformations('gumbel', (32561, 29))\n",
"# TODO: [bug] set_transformations currently raises an error when it is called with \n",
"# SoftmaxTransformation() or GumbelSoftmaxTransformation() instances,\n",
"# instead of the string names \"ohe\" or \"gumbel\"; re-enable the tests below once fixed.\n",
"# test_set_transformations(SoftmaxTransformation(), (32561, 29))\n",
"# test_set_transformations(GumbelSoftmaxTransformation(), (32561, 29))"
]
},
{
Expand Down
2 changes: 2 additions & 0 deletions relax/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@
'relax/data_utils.py'),
'relax.data_utils.SoftmaxTransformation': ( 'data.utils.html#softmaxtransformation',
'relax/data_utils.py'),
'relax.data_utils.SoftmaxTransformation.__init__': ( 'data.utils.html#softmaxtransformation.__init__',
'relax/data_utils.py'),
'relax.data_utils.SoftmaxTransformation.soft_constraints': ( 'data.utils.html#softmaxtransformation.soft_constraints',
'relax/data_utils.py'),
'relax.data_utils.Transformation': ('data.utils.html#transformation', 'relax/data_utils.py'),
Expand Down
9 changes: 6 additions & 3 deletions relax/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def apply_constraints(self, xs, cfs, **kwargs):

# %% ../nbs/01_data.utils.ipynb 23
class _OneHotTransformation(Transformation):
def __init__(self):
super().__init__("ohe", OneHotEncoder())
def __init__(self, name: str = None):
super().__init__(name, OneHotEncoder())

@property
def num_categories(self) -> int:
Expand All @@ -243,6 +243,9 @@ def compute_reg_loss(self, xs, cfs, hard: bool = False):

# %% ../nbs/01_data.utils.ipynb 24
class SoftmaxTransformation(_OneHotTransformation):
def __init__(self):
super().__init__("ohe")

def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):
x, rng_key, kwargs = operand
return jax.nn.softmax(x, axis=-1)
Expand All @@ -251,7 +254,7 @@ class GumbelSoftmaxTransformation(_OneHotTransformation):
"""Apply Gumbel softmax tricks for categorical transformation."""

def __init__(self, tau: float = 1.):
super().__init__()
super().__init__("gumbel")
self.tau = tau

def soft_constraints(self, operand: tuple[jax.Array, jax.random.PRNGKey, dict]):
Expand Down

0 comments on commit 7984e05

Please sign in to comment.