From 56bd75edbd524e6c6738ca7f56bf42855169c648 Mon Sep 17 00:00:00 2001
From: jrzaurin
Date: Thu, 11 Feb 2021 12:08:31 +0100
Subject: [PATCH] Reviewed all the writing: docs, README and examples. Adjusted
and refined types. Ready to publish
---
README.md | 25 ++---
docs/index.rst | 4 +-
docs/model_components.rst | 6 +-
docs/trainer.rst | 2 +-
docs/utils/fastai_transforms.rst | 7 +-
..._Binary_Classification_with_Defaults.ipynb | 2 +-
...FineTune_and_WarmUp_Model_Components.ipynb | 4 +-
pypi_README.md | 13 +--
pytorch_widedeep/models/deep_image.py | 22 ++--
pytorch_widedeep/models/deep_text.py | 24 ++--
pytorch_widedeep/models/tab_mlp.py | 10 +-
pytorch_widedeep/models/tab_resnet.py | 42 +++----
pytorch_widedeep/models/tab_transformer.py | 25 +++--
pytorch_widedeep/models/wide_deep.py | 50 ++++-----
.../preprocessing/preprocessors.py | 16 ++-
pytorch_widedeep/training/trainer.py | 105 +++++++++++-------
pytorch_widedeep/utils/deeptabular_utils.py | 2 +-
pytorch_widedeep/utils/fastai_transforms.py | 19 ++--
.../test_initializers.py | 4 +-
19 files changed, 202 insertions(+), 180 deletions(-)
diff --git a/README.md b/README.md
index d3242196..4ee90ba5 100644
--- a/README.md
+++ b/README.md
@@ -22,10 +22,8 @@ using wide and deep models.
### Introduction
-`pytorch-widedeep` is based on Google's Wide and Deep Algorithm. Details of
-the original algorithm can be found
-[here](https://www.tensorflow.org/tutorials/wide_and_deep), and the research
-paper can be found [here](https://arxiv.org/abs/1606.07792).
+`pytorch-widedeep` is based on Google's Wide and Deep Algorithm, [Wide & Deep
+Learning for Recommender Systems](https://arxiv.org/abs/1606.07792).
In general terms, `pytorch-widedeep` is a package to use deep learning with
tabular data. In particular, it is intended to facilitate the combination of text
@@ -40,18 +38,20 @@ in the Figure below:
The dashed boxes in the figure represent optional, overall components, and the
-dashed lines indicate the corresponding connections, depending on whether or
-not certain components are present. For example, the dashed, blue-lines
-indicate that the ``deeptabular``, ``deeptext`` and ``deepimage`` components
-are connected directly to the output neuron or neurons (depending on whether
-we are performing a binary classification or regression, or a multi-class
-classification) if the optional ``deephead`` is not present. The components
-within the faded-pink rectangle are concatenated.
+dashed lines/arrows indicate the corresponding connections, depending on
+whether or not certain components are present. For example, the dashed,
+blue-lines indicate that the ``deeptabular``, ``deeptext`` and ``deepimage``
+components are connected directly to the output neuron or neurons (depending
+on whether we are performing a binary classification or regression, or a
+multi-class classification) if the optional ``deephead`` is not present.
+Finally, the components within the faded-pink rectangle are concatenated.
Note that it is not possible to illustrate the number of possible
architectures and components available in ``pytorch-widedeep`` in one Figure.
Therefore, for more details on possible architectures (and more) please, see
-the documentation, or the Examples folders and the notebooks there.
+the
+[documentation](https://pytorch-widedeep.readthedocs.io/en/latest/index.html),
+or the Examples folders and the notebooks there.
In math terms, and following the notation in the
[paper](https://arxiv.org/abs/1606.07792), the expression for the architecture
@@ -187,7 +187,6 @@ from pytorch_widedeep.metrics import Accuracy
# the following 4 lines are not directly related to ``pytorch-widedeep``. I
# assume you have downloaded the dataset and place it in a dir called
# data/adult/
-
df = pd.read_csv("data/adult/adult.csv.zip")
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
diff --git a/docs/index.rst b/docs/index.rst
index cde440fe..3d762cbc 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -53,7 +53,9 @@ within the faded-pink rectangle are concatenated.
Note that it is not possible to illustrate the number of possible
architectures and components available in ``pytorch-widedeep`` in one Figure.
Therefore, for more details on possible architectures (and more) please, read
-this documentation, or see the Examples folders in the repo.
+this documentation, or see the `Examples
+`_ folders
+in the repo.
In math terms, and following the notation in the `paper
`_, the expression for the architecture
diff --git a/docs/model_components.rst b/docs/model_components.rst
index 940b7b14..cd968440 100644
--- a/docs/model_components.rst
+++ b/docs/model_components.rst
@@ -1,9 +1,9 @@
The ``models`` module
======================
-This module contains the four main components that will comprise Wide and Deep
-model, and the ``WideDeep`` "constructor" class. These four components are:
-``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
+This module contains the four main components that will comprise a Wide and
+Deep model, and the ``WideDeep`` "constructor" class. These four components
+are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
.. note:: ``TabMlp``, ``TabResnet`` and ``TabTransformer`` can all be used
as the ``deeptabular`` component of the model and simply represent
diff --git a/docs/trainer.rst b/docs/trainer.rst
index de27b296..7483c9bb 100644
--- a/docs/trainer.rst
+++ b/docs/trainer.rst
@@ -3,7 +3,7 @@ Training wide and deep models for tabular data
`...` or just deep learning models for tabular data.
-Here is the documentation for the ``Trainer`` class, that will do all the heavy lifting
+Here is the documentation for the ``Trainer`` class, which will do all the heavy lifting.
Trainer is also available from ``pytorch-widedeep`` directly, for example, one could do:
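The sentence above introduces a short snippet in the rendered docs; for reference, the import it refers to is presumably just:

```python
from pytorch_widedeep import Trainer

# the Trainer then wraps a WideDeep model, e.g.
# trainer = Trainer(model, objective="binary")
```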
diff --git a/docs/utils/fastai_transforms.rst b/docs/utils/fastai_transforms.rst
index 54a89e1f..f18d60b0 100644
--- a/docs/utils/fastai_transforms.rst
+++ b/docs/utils/fastai_transforms.rst
@@ -5,9 +5,10 @@ I have directly copied and pasted part of the ``transforms.py`` module from
the ``fastai`` library. The reason to do such a thing is because
``pytorch_widedeep`` only needs the ``Tokenizer`` and the ``Vocab`` classes
there. This way I avoid extra dependencies. Credit for all the code in the
-``fastai_transforms`` module to Jeremy Howard and the `fastai` team. I only
-include the documentation here for completion, but I strongly advise the user
-to read the ``fastai`` `documentation `_.
+``fastai_transforms`` module in this ``pytorch-widedeep`` package goes to
+Jeremy Howard and the `fastai` team. I only include the documentation here for
+completeness, but I strongly advise the user to read the ``fastai``
+`documentation `_.
.. autoclass:: pytorch_widedeep.utils.fastai_transforms.Tokenizer
:members:
diff --git a/examples/03_Binary_Classification_with_Defaults.ipynb b/examples/03_Binary_Classification_with_Defaults.ipynb
index a506aa98..d0428449 100644
--- a/examples/03_Binary_Classification_with_Defaults.ipynb
+++ b/examples/03_Binary_Classification_with_Defaults.ipynb
@@ -859,7 +859,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "The only requisite is that the model component must be passed to `WideDeep` before \"fed\" to the `Trainer`. This is because the `Trainer` is coded so that it trains a model that has a parent called `model` and then children that correspond to the model components: `wide`, `deeptabular`, `deeptext` and `deepimage`. "
+    "The only requirement is that the model component must be passed to `WideDeep` before being \"fed\" to the `Trainer`. This is because the `Trainer` is coded so that it trains a model that has a parent called `model` and then children that correspond to the model components: `wide`, `deeptabular`, `deeptext` and `deepimage`. Also, `WideDeep` builds the last connection between the output of those components and the final output neuron(s)."
]
}
],
diff --git a/examples/06_FineTune_and_WarmUp_Model_Components.ipynb b/examples/06_FineTune_and_WarmUp_Model_Components.ipynb
index ea1e0075..6564f5d1 100644
--- a/examples/06_FineTune_and_WarmUp_Model_Components.ipynb
+++ b/examples/06_FineTune_and_WarmUp_Model_Components.ipynb
@@ -331,7 +331,7 @@
"metadata": {},
"outputs": [],
"source": [
- "trainer2 = Trainer(model=\"models_dir/model1.t\", objective=\"binary\", metrics=[Accuracy])"
+ "trainer2 = Trainer(model_path=\"models_dir/model1.t\", objective=\"binary\", metrics=[Accuracy])"
]
},
{
@@ -949,7 +949,7 @@
"metadata": {},
"outputs": [],
"source": [
- "trainer7 = Trainer(model=\"models_dir/model3.t\", objective=\"binary\", metrics=[Accuracy])"
+ "trainer7 = Trainer(model_path=\"models_dir/model3.t\", objective=\"binary\", metrics=[Accuracy])"
]
},
{
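The two cells above switch from ``model=`` to ``model_path=`` when restoring a previously saved model. A minimal sketch of the intended save/restore round trip, assuming the model was persisted with plain ``torch.save`` (the saving step is not part of this hunk, so it is an assumption):

```python
import torch
from pytorch_widedeep import Trainer
from pytorch_widedeep.metrics import Accuracy

# assumed saving step: `trainer` is an already fitted Trainer
torch.save(trainer.model, "models_dir/model1.t")

# later, rebuild a Trainer from the saved model file via `model_path`
trainer2 = Trainer(model_path="models_dir/model1.t", objective="binary", metrics=[Accuracy])
```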
diff --git a/pypi_README.md b/pypi_README.md
index 7510ead4..584369f6 100644
--- a/pypi_README.md
+++ b/pypi_README.md
@@ -17,10 +17,8 @@ using wide and deep models.
### Introduction
-`pytorch-widedeep` is based on Google's Wide and Deep Algorithm. Details of
-the original algorithm can be found
-[here](https://www.tensorflow.org/tutorials/wide_and_deep), and the research
-paper can be found [here](https://arxiv.org/abs/1606.07792).
+`pytorch-widedeep` is based on Google's Wide and Deep Algorithm, [Wide & Deep
+Learning for Recommender Systems](https://arxiv.org/abs/1606.07792).
In general terms, `pytorch-widedeep` is a package to use deep learning with
tabular data. In particular, it is intended to facilitate the combination of text
@@ -84,12 +82,6 @@ Binary classification with the [adult
dataset]([adult](https://www.kaggle.com/wenruliu/adult-income-dataset))
using `Wide` and `DeepDense` and defaults settings.
-### Quick start
-
-Binary classification with the [adult
-dataset]([adult](https://www.kaggle.com/wenruliu/adult-income-dataset))
-using `Wide` and `DeepDense` and defaults settings.
-
```python
```
@@ -110,7 +102,6 @@ from pytorch_widedeep.metrics import Accuracy
# the following 4 lines are not directly related to ``pytorch-widedeep``. I
# assume you have downloaded the dataset and place it in a dir called
# data/adult/
-
df = pd.read_csv("data/adult/adult.csv.zip")
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
diff --git a/pytorch_widedeep/models/deep_image.py b/pytorch_widedeep/models/deep_image.py
index 45bf59e1..4938bac4 100644
--- a/pytorch_widedeep/models/deep_image.py
+++ b/pytorch_widedeep/models/deep_image.py
@@ -32,11 +32,11 @@ def __init__(
resnet_architecture: int = 18,
freeze_n: int = 6,
head_hidden_dims: Optional[List[int]] = None,
- head_activation: Optional[str] = "relu",
- head_dropout: Optional[float] = None,
- head_batchnorm: Optional[bool] = False,
- head_batchnorm_last: Optional[bool] = False,
- head_linear_first: Optional[bool] = False,
+ head_activation: str = "relu",
+ head_dropout: float = 0.1,
+ head_batchnorm: bool = False,
+ head_batchnorm_last: bool = False,
+ head_linear_first: bool = False,
):
r"""
Standard image classifier/regressor using a pretrained network (in
@@ -69,19 +69,19 @@ def __init__(
freeze_n: int, default = 6
number of layers to freeze. Must be less than or equal to 8. If 8
the entire 'backbone' of the network will be frozen
- head_hidden_dims: List, Optional
+ head_hidden_dims: List, Optional, default = None
List with the number of neurons per dense layer in the head. e.g: [64,32]
- head_activation: str, Optional, default = "relu"
+ head_activation: str, default = "relu"
Activation function for the dense layers in the head.
- head_dropout: float, Optional, default = 0.
+ head_dropout: float, default = 0.1
float indicating the dropout between the dense layers.
- head_batchnorm: bool, Optional, default = False
+ head_batchnorm: bool, default = False
Boolean indicating whether or not batch normalization will be applied
to the dense layers
- head_batchnorm_last: bool, Optional, default = False
+ head_batchnorm_last: bool, default = False
Boolean indicating whether or not batch normalization will be applied
to the last of the dense layers
- head_linear_first: bool, Optional, default = False
+ head_linear_first: bool, default = False
Boolean indicating the order of the operations in the dense
layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
LIN -> ACT]``
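For illustration, a sketch of building the ``deepimage`` component with the head parameters documented above. Only parameters that appear in this diff are used; the import path and the 224x224 input size for the pretrained resnet backbone are assumptions.

```python
import torch
from pytorch_widedeep.models import DeepImage

deepimage = DeepImage(
    resnet_architecture=18,      # resnet18 backbone
    freeze_n=6,                  # freeze the first 6 layers of the backbone
    head_hidden_dims=[128, 64],  # optional dense head on top of the CNN features
    head_dropout=0.1,
)
out = deepimage(torch.rand(2, 3, 224, 224))  # batch of 2 RGB images -> (2, 64)
```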
diff --git a/pytorch_widedeep/models/deep_text.py b/pytorch_widedeep/models/deep_text.py
index 92635912..275b3426 100644
--- a/pytorch_widedeep/models/deep_text.py
+++ b/pytorch_widedeep/models/deep_text.py
@@ -19,13 +19,13 @@ def __init__(
padding_idx: int = 1,
embed_dim: Optional[int] = None,
embed_matrix: Optional[np.ndarray] = None,
- embed_trainable: Optional[bool] = True,
+ embed_trainable: bool = True,
head_hidden_dims: Optional[List[int]] = None,
- head_activation: Optional[str] = "relu",
+ head_activation: str = "relu",
head_dropout: Optional[float] = None,
- head_batchnorm: Optional[bool] = False,
- head_batchnorm_last: Optional[bool] = False,
- head_linear_first: Optional[bool] = False,
+ head_batchnorm: bool = False,
+ head_batchnorm_last: bool = False,
+ head_linear_first: bool = False,
):
r"""Standard text classifier/regressor comprised by a stack of RNNs
(in particular LSTMs).
@@ -47,30 +47,30 @@ def __init__(
bidirectional: bool, default = True
indicates whether the stacked RNNs are bidirectional
padding_idx: int, default = 1
- index of the padding token in the padded-tokenised sequences. default:
- 1. I use the ``fastai`` tokenizer where the token index 0 is reserved
+ index of the padding token in the padded-tokenised sequences. I
+ use the ``fastai`` tokenizer where the token index 0 is reserved
for the `'unknown'` word token
embed_dim: int, Optional, default = None
Dimension of the word embedding matrix if non-pretrained word
vectors are used
embed_matrix: np.ndarray, Optional, default = None
Pretrained word embeddings
- embed_trainable: bool, Optional, default = None
+ embed_trainable: bool, default = True
Boolean indicating if the pretrained embeddings are trainable
head_hidden_dims: List, Optional, default = None
List with the sizes of the stacked dense layers in the head
e.g: [128, 64]
- head_activation: str, Optional, default = "relu"
+ head_activation: str, default = "relu"
Activation function for the dense layers in the head
head_dropout: float, Optional, default = None
dropout between the dense layers in the head
- head_batchnorm: bool, Optional, default = False
+ head_batchnorm: bool, default = False
Whether or not to include batch normalization in the dense layers that
form the `'texthead'`
- head_batchnorm_last: bool, Optional, default = False
+ head_batchnorm_last: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
last of the dense layers in the head
- head_linear_first: bool, Optional, default = False
+ head_linear_first: bool, default = False
Boolean indicating the order of the operations in the dense
layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
LIN -> ACT]``
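Similarly, a short sketch of the ``deeptext`` component using the parameters documented above; ``vocab_size`` does not appear in this hunk and is assumed to be the first argument.

```python
import torch
from pytorch_widedeep.models import DeepText

deeptext = DeepText(
    vocab_size=2000,            # assumed parameter: size of the tokenizer vocabulary
    embed_dim=32,               # non-pretrained word embeddings
    padding_idx=1,              # fastai-style tokenisation: 0 is 'unknown', 1 is padding
    head_hidden_dims=[64, 32],
)
tokens = torch.randint(0, 2000, (8, 20))  # batch of 8 padded sequences of length 20
out = deeptext(tokens)
```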
diff --git a/pytorch_widedeep/models/tab_mlp.py b/pytorch_widedeep/models/tab_mlp.py
index 501824f0..34835642 100644
--- a/pytorch_widedeep/models/tab_mlp.py
+++ b/pytorch_widedeep/models/tab_mlp.py
@@ -75,14 +75,14 @@ def __init__(
column_idx: Dict[str, int],
mlp_hidden_dims: List[int] = [200, 100],
mlp_activation: str = "relu",
- mlp_dropout: Optional[Union[float, List[float]]] = 0.1,
+ mlp_dropout: Union[float, List[float]] = 0.1,
mlp_batchnorm: bool = False,
mlp_batchnorm_last: bool = False,
mlp_linear_first: bool = False,
embed_input: Optional[List[Tuple[str, int, int]]] = None,
embed_dropout: float = 0.1,
continuous_cols: Optional[List[str]] = None,
- batchnorm_cont: Optional[bool] = False,
+ batchnorm_cont: bool = False,
):
r"""Defines a ``TabMlp`` model that can be used as the ``deeptabular``
component of a Wide & Deep model.
@@ -102,7 +102,7 @@ def __init__(
mlp_activation: str, default = "relu"
Activation function for the dense layers of the MLP. Currently
only "relu", "leaky_relu" and "gelu" are supported
- mlp_dropout: float or List, Optional, default = 0.1
+ mlp_dropout: float or List, default = 0.1
float or List of floats with the dropout between the dense layers.
e.g: [0.5,0.5]
mlp_batchnorm: bool, default = False
@@ -115,12 +115,12 @@ def __init__(
Boolean indicating the order of the operations in the dense
layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
LIN -> ACT]``
- embed_input: List, Optional
+ embed_input: List, Optional, default = None
List of Tuples with the column name, number of unique values and
embedding dimension. e.g. [(education, 11, 32), ...]
embed_dropout: float, default = 0.1
embeddings dropout
- continuous_cols: List, Optional
+ continuous_cols: List, Optional, default = None
List with the name of the numeric (aka continuous) columns
batchnorm_cont: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
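A compact sketch of ``TabMlp`` on its own, using only the parameters documented above. The float input tensor with label-encoded categorical columns follows the convention assumed for the tabular preprocessor output.

```python
import torch
from pytorch_widedeep.models import TabMlp

column_idx = {"education": 0, "relationship": 1, "age": 2}
embed_input = [("education", 11, 16), ("relationship", 6, 8)]  # (col, n_unique, embed_dim)

tabmlp = TabMlp(
    column_idx=column_idx,
    embed_input=embed_input,
    continuous_cols=["age"],
    mlp_hidden_dims=[64, 32],
    mlp_dropout=0.1,
)
X_tab = torch.tensor([[1.0, 2.0, 0.53], [3.0, 0.0, -1.17]])  # categorical cols are label encoded
out = tabmlp(X_tab)  # (2, 32): output of the last dense layer
```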
diff --git a/pytorch_widedeep/models/tab_resnet.py b/pytorch_widedeep/models/tab_resnet.py
index fc4d8587..4f5ace93 100644
--- a/pytorch_widedeep/models/tab_resnet.py
+++ b/pytorch_widedeep/models/tab_resnet.py
@@ -90,15 +90,15 @@ def __init__(
blocks_dims: List[int] = [200, 100, 100],
blocks_dropout: float = 0.1,
mlp_hidden_dims: Optional[List[int]] = None,
- mlp_activation: Optional[str] = "relu",
- mlp_dropout: Optional[float] = 0.1,
- mlp_batchnorm: Optional[bool] = False,
- mlp_batchnorm_last: Optional[bool] = False,
- mlp_linear_first: Optional[bool] = False,
- embed_dropout: Optional[float] = 0.1,
+ mlp_activation: str = "relu",
+ mlp_dropout: float = 0.1,
+ mlp_batchnorm: bool = False,
+ mlp_batchnorm_last: bool = False,
+ mlp_linear_first: bool = False,
+ embed_dropout: float = 0.1,
continuous_cols: Optional[List[str]] = None,
- batchnorm_cont: Optional[bool] = False,
- concat_cont_first: Optional[bool] = True,
+ batchnorm_cont: bool = False,
+ concat_cont_first: bool = True,
):
r"""Defines a so-called ``TabResnet`` model that can be used as the
``deeptabular`` component of a Wide & Deep model.
@@ -136,28 +136,28 @@ def __init__(
[64, 32]. If ``None`` the output of the Resnet Blocks will be
connected directly to the output neuron(s), i.e. using a MLP is
optional.
- mlp_activation: str, Optional, default = "relu"
+ mlp_activation: str, default = "relu"
Activation function for the dense layers of the MLP
- mlp_dropout: float, Optional, default = 0.1
+ mlp_dropout: float, default = 0.1
float with the dropout between the dense layers of the MLP.
- mlp_batchnorm: bool, Optional, default = False
+ mlp_batchnorm: bool, default = False
Boolean indicating whether or not batch normalization will be applied
to the dense layers
- mlp_batchnorm_last: bool, Optional, default = False
+ mlp_batchnorm_last: bool, default = False
Boolean indicating whether or not batch normalization will be applied
to the last of the dense layers
- mlp_linear_first: bool, Optional, default = False
+ mlp_linear_first: bool, default = False
Boolean indicating the order of the operations in the dense
layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
LIN -> ACT]``
- embed_dropout: float, Optional, default = 0.1
+ embed_dropout: float, default = 0.1
embeddings dropout
continuous_cols: List, Optional, default = None
List with the name of the numeric (aka continuous) columns
batchnorm_cont: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
continuous input
- concat_cont_first: bool, Optional, default = True
+ concat_cont_first: bool, default = True
Boolean indicating whether the continuous columns will be
concatenated with the Embeddings and then passed through the
Resnet blocks (``True``) or, alternatively, will be concatenated
@@ -175,13 +175,13 @@ def __init__(
if ``mlp_hidden_dims`` is ``True``, this attribute will be an mlp
model that will receive:
- i) the results of the concatenation of the embeddings and the
- continuous columns -- if present -- and then passed it through
- the ``dense_resnet`` (``concat_cont_first = True``), or
+        - the result of concatenating the embeddings and the continuous
+          columns -- if present -- and then passing it through the
+          ``dense_resnet`` (``concat_cont_first = True``), or
- ii) the result of passing the embeddings through the ``dense_resnet``
- and then concatenating the results with the continuous columns --
- if present -- (``concat_cont_first = False``)
+ - the result of passing the embeddings through the ``dense_resnet``
+ and then concatenating the results with the continuous columns --
+ if present -- (``concat_cont_first = False``)
output_dim: `int`
The output dimension of the model. This is a required attribute
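``TabResnet`` is used in the same way; the sketch below only illustrates ``concat_cont_first`` and assumes the constructor shares the ``column_idx``/``embed_input`` interface with ``TabMlp`` (those parameters are not shown in this hunk).

```python
import torch
from pytorch_widedeep.models import TabResnet

tabresnet = TabResnet(
    column_idx={"education": 0, "age": 1},  # assumed parameter
    embed_input=[("education", 11, 16)],    # assumed parameter
    continuous_cols=["age"],
    blocks_dims=[200, 100, 100],
    concat_cont_first=True,  # embeddings + continuous cols go through the Resnet blocks together
)
out = tabresnet(torch.tensor([[1.0, 0.3], [4.0, -0.8]]))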
diff --git a/pytorch_widedeep/models/tab_transformer.py b/pytorch_widedeep/models/tab_transformer.py
index 0ae7b0f7..5f955c50 100644
--- a/pytorch_widedeep/models/tab_transformer.py
+++ b/pytorch_widedeep/models/tab_transformer.py
@@ -222,10 +222,10 @@ def __init__(
ff_hidden_dim: int = 32 * 4,
transformer_activation: str = "gelu",
mlp_hidden_dims: Optional[List[int]] = None,
- mlp_activation: Optional[str] = "relu",
- mlp_batchnorm: Optional[bool] = False,
- mlp_batchnorm_last: Optional[bool] = False,
- mlp_linear_first: Optional[bool] = True,
+ mlp_activation: str = "relu",
+ mlp_batchnorm: bool = False,
+ mlp_batchnorm_last: bool = False,
+ mlp_linear_first: bool = True,
):
r"""TabTransformer model (https://arxiv.org/pdf/2012.06678.pdf) model that
@@ -248,7 +248,8 @@ def __init__(
full_embed_dropout: bool, default = False
Boolean indicating if an entire embedding (i.e. the representation
for one categorical column) will be dropped in the batch. See:
- ``pytorch_widedeep.model.tab_transformer.FullEmbeddingDropout``
+ ``pytorch_widedeep.model.tab_transformer.FullEmbeddingDropout``.
+ If ``full_embed_dropout = True``, ``embed_dropout`` is ignored.
shared_embed: bool, default = False
The idea behind ``shared_embed`` is described in the Appendix A in the paper:
`'The goal of having column embedding is to enable the model to distinguish the
@@ -274,7 +275,7 @@ def __init__(
``pytorch_widedeep.model.tab_transformer.TransformerEncoder``) and the
output MLP
keep_attn_weights: bool, default = False
- If set to ``True`` the model will store the attention weights in the ``blk.self_attn.attn_weights``
+ If set to ``True`` the model will store the attention weights in the ``attention_weights``
attribute.
fixed_attention: bool, default = False
If set to ``True`` all the observations in a batch will have the
@@ -290,17 +291,17 @@ def __init__(
transformer_activation: str, default = "gelu"
Transformer Encoder activation function
mlp_hidden_dims: List, Optional, default = None
- MLP hidden dimensions. If not provided it will default to ``[4*l, 2*l]`` where l is the
- mlp input dimension
- mlp_activation: str, Optional, default = "gelu"
+ MLP hidden dimensions. If not provided it will default to ``[4*l,
+ 2*l]`` where ``l`` is the mlp input dimension
+    mlp_activation: str, default = "relu"
MLP activation function
- mlp_batchnorm: bool, Optional, default = False
+ mlp_batchnorm: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
dense layers
- mlp_batchnorm_last: bool, Optional, default = False
+ mlp_batchnorm_last: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
last of the dense layers
- mlp_linear_first: bool, Optional, default = False
+ mlp_linear_first: bool, default = False
Boolean indicating whether the order of the operations in the dense
layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
LIN -> ACT]``
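A short sketch of the ``keep_attn_weights`` behaviour documented above. Everything other than ``keep_attn_weights`` and the ``attention_weights`` attribute is an assumption, in particular the ``embed_input`` format, taken here as ``(column, n_unique)`` tuples since all columns share the same embedding dim.

```python
import torch
from pytorch_widedeep.models import TabTransformer

tabtransformer = TabTransformer(
    column_idx={"education": 0, "relationship": 1},         # assumed parameter
    embed_input=[("education", 11), ("relationship", 6)],   # assumed format
    keep_attn_weights=True,
)
out = tabtransformer(torch.tensor([[1.0, 2.0], [3.0, 0.0]]))
# with keep_attn_weights=True the per-block attention matrices are stored here
attn = tabtransformer.attention_weights
```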
diff --git a/pytorch_widedeep/models/wide_deep.py b/pytorch_widedeep/models/wide_deep.py
index ce42766d..4945d6cf 100644
--- a/pytorch_widedeep/models/wide_deep.py
+++ b/pytorch_widedeep/models/wide_deep.py
@@ -21,11 +21,11 @@ def __init__(
deepimage: Optional[nn.Module] = None,
deephead: Optional[nn.Module] = None,
head_hidden_dims: Optional[List[int]] = None,
- head_activation: Optional[str] = "relu",
- head_dropout: Optional[float] = 0.1,
- head_batchnorm: Optional[bool] = False,
- head_batchnorm_last: Optional[bool] = False,
- head_linear_first: Optional[bool] = False,
+ head_activation: str = "relu",
+ head_dropout: float = 0.1,
+ head_batchnorm: bool = False,
+ head_batchnorm_last: bool = False,
+ head_linear_first: bool = False,
pred_dim: int = 1,
):
r"""Main collector class that combines all ``wide``, ``deeptabular``
@@ -52,7 +52,7 @@ def __init__(
currently ``pytorch-widedeep`` implements three possible
architectures for the `deeptabular` component. These are:
- ``TabMlp``, ``TabResnet`` and ` ``TabTransformer``.
+ ``TabMlp``, ``TabResnet`` and ``TabTransformer``.
1. ``TabMlp`` is simply an embedding layer encoding the categorical
features that are then concatenated and passed through a series of
@@ -90,32 +90,30 @@ def __init__(
head_hidden_dims: List, Optional, default = None
Alternatively, the ``head_hidden_dims`` param can be used to
specify the sizes of the stacked dense layers in the fc-head e.g:
- ``[128, 64]``
- head_dropout: float, Optional, default = 0.1
- Dropout between the layers in ``head_hidden_dims``
+ ``[128, 64]``. Use ``deephead`` or ``head_hidden_dims``, but not
+ both.
+ head_dropout: float, default = 0.1
+ If ``head_hidden_dims`` is not None, dropout between the layers in
+ ``head_hidden_dims``
head_activation: str, default = "relu"
- activation function of the head layers. One of "relu", gelu" or
- "leaky_relu"
- head_batchnorm: bool, Optional, default = False
- Specifies if batch normalizatin should be included in the head layers
- head_batchnorm_last: bool, Optional, default = False
- Boolean indicating whether or not to apply batch normalization to the
- last of the dense layers
- head_linear_first: bool, Optional, default = False
- Boolean indicating whether the order of the operations in the
- dense layer. If ``True``: ``[LIN -> ACT -> BN -> DP]``. If
- ``False``: ``[BN -> DP -> LIN -> ACT]``
+ If ``head_hidden_dims`` is not None, activation function of the
+      head layers. One of "relu", "gelu" or "leaky_relu"
+ head_batchnorm: bool, default = False
+ If ``head_hidden_dims`` is not None, specifies if batch
+      normalization should be included in the head layers
+ head_batchnorm_last: bool, default = False
+ If ``head_hidden_dims`` is not None, boolean indicating whether or
+ not to apply batch normalization to the last of the dense layers
+ head_linear_first: bool, default = False
+      If ``head_hidden_dims`` is not None, boolean indicating the order
+      of the operations in the dense layer. If ``True``:
+ ``[LIN -> ACT -> BN -> DP]``. If ``False``: ``[BN -> DP -> LIN ->
+ ACT]``
pred_dim: int, default = 1
Size of the final wide and deep output layer containing the
predictions. `1` for regression and binary classification or number
of classes for multiclass classification.
- Attributes
- ----------
- cyclic_lr: bool
- Attribute that indicates if any of the lr_schedulers is cyclic_lr (i.e. ``CyclicLR`` or
- ``OneCycleLR``). See `Pytorch schedulers `_.
-
Examples
--------
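To make the ``deephead`` vs ``head_hidden_dims`` note above concrete, a minimal sketch reusing the toy ``TabMlp`` setup from this patch; the ``Wide`` signature is an assumption.

```python
from pytorch_widedeep.models import Wide, TabMlp, WideDeep

wide = Wide(wide_dim=100, pred_dim=1)  # assumed signature
deeptabular = TabMlp(
    column_idx={"education": 0, "age": 1},
    embed_input=[("education", 11, 16)],
    continuous_cols=["age"],
    mlp_hidden_dims=[64, 32],
)

# Option 1: let WideDeep build the fc-head from head_hidden_dims
model = WideDeep(wide=wide, deeptabular=deeptabular, head_hidden_dims=[16, 8], head_dropout=0.1)

# Option 2: pass a custom deephead nn.Module instead, but never both at once
```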
diff --git a/pytorch_widedeep/preprocessing/preprocessors.py b/pytorch_widedeep/preprocessing/preprocessors.py
index db2a10cd..50d941fa 100644
--- a/pytorch_widedeep/preprocessing/preprocessors.py
+++ b/pytorch_widedeep/preprocessing/preprocessors.py
@@ -213,7 +213,7 @@ def __init__(
continuous_cols: List[str] = None,
scale: bool = True,
default_embed_dim: int = 16,
- already_standard: Optional[List[str]] = None,
+ already_standard: List[str] = None,
for_tabtransformer: bool = False,
verbose: int = 1,
):
@@ -236,17 +236,23 @@ def __init__(
:class:`pytorch_widedeep.models`
default_embed_dim: int, default=16
Dimension for the embeddings used for the ``deeptabular``
- component
- already_standard: List, Optional, default = None
+ component if the embed_dim is not provided in the ``embed_cols``
+ parameter
+ already_standard: List, default = None
List with the name of the continuous cols that do not need to be
- Standarised.
+        Standardised. For example, you might have Long and Lat in your
+ dataset and might want to encode them somehow (e.g. see the
+ ``LatLongScalarEnc`` available in the `autogluon
+ `_
+        tabular library) and NOT standardise them any further
for_tabtransformer: bool, default = False
Boolean indicating whether the preprocessed data will be passed to
a ``TabTransformer`` model. If ``True``, the param ``embed_cols``
must just be a list containing the categorical columns: e.g.:
['education', 'relationship', ...] This is because following the
results in the `paper `_,
- they will all be encoded using embeddings of dim 32. See
+ they will all be encoded using embeddings of the same dim (32 by
+ default). See
:class:`pytorch_widedeep.models.tab_transformer.TabTransformer`
verbose: int, default = 1
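The ``already_standard`` behaviour described above, sketched with a hypothetical pre-encoded coordinate column (the column names are made up for illustration).

```python
import pandas as pd
from pytorch_widedeep.preprocessing import TabPreprocessor

df = pd.DataFrame(
    {
        "education": ["11th", "HS-grad", "Assoc-acdm"],
        "age": [25, 38, 28],
        "lat_enc": [0.12, -0.33, 0.85],  # hypothetical, already-encoded latitude feature
    }
)

tab_preprocessor = TabPreprocessor(
    embed_cols=[("education", 16)],
    continuous_cols=["age", "lat_enc"],
    already_standard=["lat_enc"],  # scale "age" only; leave "lat_enc" as it is
)
X_tab = tab_preprocessor.fit_transform(df)
```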
diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py
index 3fc92891..88114a51 100644
--- a/pytorch_widedeep/training/trainer.py
+++ b/pytorch_widedeep/training/trainer.py
@@ -106,12 +106,15 @@ def __init__(
'binary', 'multiclass' or 'regression', consistent with the loss
function
- optimizers: ``Optimzer`` or Dict, Optional, default= ``AdamW``
+    optimizers: ``Optimizer`` or Dict, Optional, default=None
- An instance of Pytorch's ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or
- a dictionary where the keys are the model components (i.e.
`'wide'`, `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and
the values are the corresponding optimizers. If multiple optimizers are used
the dictionary **MUST** contain an optimizer per model component.
+
+      If no optimizers are passed, it will default to ``AdamW`` for all
+ Wide and Deep components
lr_schedulers: ``LRScheduler`` or Dict, Optional, default=None
- An instance of Pytorch's ``LRScheduler`` object (e.g
:obj:`torch.optim.lr_scheduler.StepLR(opt, step_size=5)`) or
@@ -130,11 +133,11 @@ def __init__(
`_.
callbacks: List, Optional, default=None
List with ``Callback`` objects. The four callbacks available in
- ``pytorch-widedeep`` are: ``ModelCheckpoint``, ``EarlyStopping``,
- and ``LRHistory``. The ``History`` callback is used by default.
- This can also be a custom callback as long as the object of type
- ``Callback``. See ``pytorch_widedeep.callbacks.Callback`` or the
- `Examples
+ ``pytorch-widedeep`` are: ``History``, ``ModelCheckpoint``,
+ ``EarlyStopping``, and ``LRHistory``. The ``History`` callback is
+ used by default. This can also be a custom callback as long as the
+      object is of type ``Callback``. See
+ ``pytorch_widedeep.callbacks.Callback`` or the `Examples
`_
folder in the repo
metrics: List, Optional, default=None
@@ -168,6 +171,13 @@ def __init__(
seed: int, default=1
Random seed to be used internally for train_test_split
+ Attributes
+ ----------
+ cyclic_lr: bool
+ Attribute that indicates if any of the lr_schedulers is cyclic_lr (i.e. ``CyclicLR`` or
+ ``OneCycleLR``). See `Pytorch schedulers `_.
+
+
Example
--------
>>> import torch
@@ -276,26 +286,30 @@ def fit( # noqa: C901
val_split: Optional[float] = None,
target: Optional[np.ndarray] = None,
n_epochs: int = 1,
- validation_freq: Optional[int] = 1,
+ validation_freq: int = 1,
batch_size: int = 32,
patience: int = 10,
- finetune: Optional[bool] = False,
- finetune_epochs: Optional[int] = 5,
- finetune_max_lr: Optional[float] = 0.01,
- finetune_deeptabular_gradual: Optional[bool] = False,
- finetune_deeptabular_max_lr: Optional[float] = 0.01,
+ finetune: bool = False,
+ finetune_epochs: int = 5,
+ finetune_max_lr: float = 0.01,
+ finetune_deeptabular_gradual: bool = False,
+ finetune_deeptabular_max_lr: float = 0.01,
finetune_deeptabular_layers: Optional[List[nn.Module]] = None,
- finetune_deeptext_gradual: Optional[bool] = False,
- finetune_deeptext_max_lr: Optional[float] = 0.01,
+ finetune_deeptext_gradual: bool = False,
+ finetune_deeptext_max_lr: float = 0.01,
finetune_deeptext_layers: Optional[List[nn.Module]] = None,
- finetune_deepimage_gradual: Optional[bool] = False,
- finetune_deepimage_max_lr: Optional[float] = 0.01,
+ finetune_deepimage_gradual: bool = False,
+ finetune_deepimage_max_lr: float = 0.01,
finetune_deepimage_layers: Optional[List[nn.Module]] = None,
- finetune_routine: Optional[str] = "howard",
- stop_after_finetuning: Optional[bool] = False,
+ finetune_routine: str = "howard",
+ stop_after_finetuning: bool = False,
):
r"""Fit method.
+ The input datasets can be passed either directly via numpy arrays
+ (``X_wide``, ``X_tab``, ``X_text`` or ``X_img``) or alternatively, in
+ dictionaries (``X_train`` or ``X_val``).
+
Parameters
----------
X_wide: np.ndarray, Optional. default=None
@@ -311,13 +325,13 @@ def fit( # noqa: C901
Input for the ``deepimage`` model component.
See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
X_train: Dict, Optional. default=None
- Training dataset for the different model components. Keys are
- `X_wide`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values are
- the corresponding matrices.
+ The training dataset can also be passed in a dictionary. Keys are
+      `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values
+ are the corresponding matrices.
X_val: Dict, Optional. default=None
- Validation dataset for the different model component. Keys are
- `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values are
- the corresponding matrices.
+ The validation dataset can also be passed in a dictionary. Keys
+      are `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`.
+ Values are the corresponding matrices.
val_split: float, Optional. default=None
train/val split fraction
target: np.ndarray, Optional. default=None
@@ -330,7 +344,7 @@ def fit( # noqa: C901
patience: int, default=10
Number of epochs without improving the target metric or loss
before the fit process stops
- finetune: bool, Optional, default=False
+ finetune: bool, default=False
param alias: ``warmup``
fine-tune individual model components.
@@ -363,13 +377,13 @@ def fit( # noqa: C901
section in this documentation and the `Examples
`_
folder in the repo.
- finetune_epochs: int, Optional, default=4
+    finetune_epochs: int, default=5
param alias: ``warmup_epochs``
Number of fine-tune epochs for those model components that will
*NOT* be gradually fine-tuned. Those components with gradual
fine-tune follow their corresponding specific routine.
- finetune_max_lr: float, Optional, default=0.01
+ finetune_max_lr: float, default=0.01
param alias: ``warmup_max_lr``
Maximum learning rate during the Triangular Learning rate cycle
@@ -379,7 +393,7 @@ def fit( # noqa: C901
Boolean indicating if the ``deeptabular`` component will be
fine-tuned gradually
- finetune_deeptabular_max_lr: float, Optional, default=0.01
+ finetune_deeptabular_max_lr: float, default=0.01
param alias: ``warmup_deeptabular_max_lr``
Maximum learning rate during the Triangular Learning rate cycle
@@ -392,12 +406,12 @@ def fit( # noqa: C901
.. note:: These have to be in `fine-tune-order`: the layers or blocks
close to the output neuron(s) first
- finetune_deeptext_gradual: bool, Optional, default=False
+ finetune_deeptext_gradual: bool, default=False
param alias: ``warmup_deeptext_gradual``
Boolean indicating if the ``deeptext`` component will be
fine-tuned gradually
- finetune_deeptext_max_lr: float, Optional, default=0.01
+ finetune_deeptext_max_lr: float, default=0.01
param alias: ``warmup_deeptext_max_lr``
Maximum learning rate during the Triangular Learning rate cycle
@@ -410,12 +424,12 @@ def fit( # noqa: C901
.. note:: These have to be in `fine-tune-order`: the layers or blocks
close to the output neuron(s) first
- finetune_deepimage_gradual: bool, Optional, default=False
+ finetune_deepimage_gradual: bool, default=False
param alias: ``warmup_deepimage_gradual``
Boolean indicating if the ``deepimage`` component will be
fine-tuned gradually
- finetune_deepimage_max_lr: float, Optional, default=0.01
+ finetune_deepimage_max_lr: float, default=0.01
param alias: ``warmup_deepimage_max_lr``
Maximum learning rate during the Triangular Learning rate cycle
@@ -428,10 +442,10 @@ def fit( # noqa: C901
.. note:: These have to be in `fine-tune-order`: the layers or blocks
close to the output neuron(s) first
- finetune_routine: str, Optional, default=`felbo`
- param alias: ``warmup_deepimage_layers``
+ finetune_routine: str, default = "howard"
+ param alias: ``warmup_routine``
- Warm up routine. On of `felbo` or `howard`. See the examples
+    Warm up routine. One of "felbo" or "howard". See the examples
section in this documentation and the corresponding repo for
details on how to use fine-tune routines
@@ -581,6 +595,11 @@ def predict( # type: ignore[return]
) -> np.ndarray:
r"""Returns the predictions
+ The input datasets can be passed either directly via numpy arrays
+ (``X_wide``, ``X_tab``, ``X_text`` or ``X_img``) or alternatively, in
+ a dictionary (``X_test``)
+
+
Parameters
----------
X_wide: np.ndarray, Optional. default=None
@@ -596,9 +615,9 @@ def predict( # type: ignore[return]
Input for the ``deepimage`` model component.
See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
X_test: Dict, Optional. default=None
- Dictionary with the resting dataset for the different model
- components. Keys are `'X_wide'`, `'X_tab'`, `'X_text'` and
- `'X_img'` and the values are the corresponding matrices.
+ The test dataset can also be passed in a dictionary. Keys are
+      `'X_wide'`, `'X_tab'`, `'X_text'` and `'X_img'`. Values
+ are the corresponding matrices.
"""
preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test)
@@ -622,6 +641,10 @@ def predict_proba( # type: ignore[return]
r"""Returns the predicted probabilities for the test dataset for binary
and multiclass methods
+ The input datasets can be passed either directly via numpy arrays
+ (``X_wide``, ``X_tab``, ``X_text`` or ``X_img``) or alternatively, in
+ a dictionary (``X_test``)
+
Parameters
----------
X_wide: np.ndarray, Optional. default=None
@@ -637,9 +660,9 @@ def predict_proba( # type: ignore[return]
Input for the ``deepimage`` model component.
See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
X_test: Dict, Optional. default=None
- Dictionary with the resting dataset for the different model
- components. Keys are `'X_wide'`, `'X_tab'`, `'X_text'` and
- `'X_img'` and the values are the corresponding matrices.
+ The test dataset can also be passed in a dictionary. Keys are
+      are `'X_wide'`, `'X_tab'`, `'X_text'` and `'X_img'`. Values
+ are the corresponding matrices.
"""
preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test)
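The array-vs-dictionary convention documented in the ``fit``/``predict`` docstrings above, gathered in one place; ``model``, ``X_wide``, ``X_tab`` and ``target`` are assumed to exist, e.g. as in the README example earlier in this patch.

```python
from pytorch_widedeep import Trainer
from pytorch_widedeep.metrics import Accuracy

trainer = Trainer(model, objective="binary", metrics=[Accuracy])

# 1. pass the arrays directly...
trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=2, batch_size=128)

# 2. ...or pack the same data in the X_train dictionary
X_train = {"X_wide": X_wide, "X_tab": X_tab, "target": target}
trainer.fit(X_train=X_train, n_epochs=2, batch_size=128)

# predict/predict_proba accept the same two styles
preds = trainer.predict(X_wide=X_wide, X_tab=X_tab)
preds = trainer.predict(X_test={"X_wide": X_wide, "X_tab": X_tab})
```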
diff --git a/pytorch_widedeep/utils/deeptabular_utils.py b/pytorch_widedeep/utils/deeptabular_utils.py
index 1c4c01dd..ead574c3 100644
--- a/pytorch_widedeep/utils/deeptabular_utils.py
+++ b/pytorch_widedeep/utils/deeptabular_utils.py
@@ -96,7 +96,7 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
return df_inp
def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
- """Applies the full process
+ """Combines ``fit`` and ``transform``
Examples
--------
diff --git a/pytorch_widedeep/utils/fastai_transforms.py b/pytorch_widedeep/utils/fastai_transforms.py
index 543d573c..8a0e79ab 100644
--- a/pytorch_widedeep/utils/fastai_transforms.py
+++ b/pytorch_widedeep/utils/fastai_transforms.py
@@ -217,10 +217,10 @@ def __init__(
self,
tok_func: Callable = SpacyTokenizer,
lang: str = "en",
- pre_rules: ListRules = None,
- post_rules: ListRules = None,
- special_cases: Collection[str] = None,
- n_cpus: int = None,
+ pre_rules: Optional[ListRules] = None,
+ post_rules: Optional[ListRules] = None,
+ special_cases: Optional[Collection[str]] = None,
+ n_cpus: Optional[int] = None,
):
"""Class to combine a series of rules and a tokenizer function to tokenize
text with multiprocessing.
@@ -231,16 +231,16 @@ def __init__(
Tokenizer Object. See :class:`pytorch_widedeep.utils.fastai_transforms.SpacyTokenizer`
lang: str, default = "en",
Text's Language
- pre_rules: ListRules, default = None,
+ pre_rules: ListRules, Optional, default = None,
Custom type: ``Collection[Callable[[str], str]]``.
see :obj:`pytorch_widedeep.wdtypes`. Preprocessing Rules
- post_rules: ListRules, default = None,
+ post_rules: ListRules, Optional, default = None,
Custom type: ``Collection[Callable[[str], str]]``.
see :obj:`pytorch_widedeep.wdtypes`. Postprocessing Rules
- special_cases: Collection, default= None,
+ special_cases: Collection, Optional, default= None,
special cases to be added to the tokenizer via ``Spacy``'s
``add_special_case`` method
- n_cpus: int, default = None
+ n_cpus: int, Optional, default = None
number of CPUs to use during the tokenization process
"""
self.tok_func, self.lang, self.special_cases = tok_func, lang, special_cases
@@ -267,7 +267,8 @@ def process_text(self, t: str, tok: BaseTokenizer) -> List[str]:
t: str
text to be processed and tokenized
tok: ``BaseTokenizer``
- Instance of :obj:`BaseTokenizer`
+ Instance of :obj:`BaseTokenizer`. See
+ ``pytorch_widedeep.utils.fastai_transforms.BaseTokenizer``
"""
for rule in self.pre_rules:
t = rule(t)
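A tiny usage sketch for the ``Tokenizer`` documented above; ``process_all`` is assumed to be carried over unchanged from the original ``fastai`` implementation.

```python
from pytorch_widedeep.utils.fastai_transforms import Tokenizer

texts = ["pytorch-widedeep combines tabular data with text and images"]
tok = Tokenizer()                # defaults: SpacyTokenizer, lang="en"
tokens = tok.process_all(texts)  # `process_all` assumed; tokenizes a list of texts
```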
diff --git a/tests/test_model_functioning/test_initializers.py b/tests/test_model_functioning/test_initializers.py
index 1f1afd44..4bafbfb5 100644
--- a/tests/test_model_functioning/test_initializers.py
+++ b/tests/test_model_functioning/test_initializers.py
@@ -180,11 +180,11 @@ def test_initializers_with_pattern():
)
def test_single_initializer(model, initializer):
- inp_weights = model.wide.wide_linear.weight.data.detach()
+ inp_weights = model.wide.wide_linear.weight.data.detach().cpu()
n_model = c(model)
trainer = Trainer(n_model, objective="binary", initializers=initializer)
- init_weights = trainer.model.wide.wide_linear.weight.data
+ init_weights = trainer.model.wide.wide_linear.weight.data.detach().cpu()
assert not torch.all(inp_weights == init_weights)