Reviewed all the writing: docs, README and examples. Adjusted and refined types. Ready to publish
jrzaurin committed Feb 11, 2021
1 parent 83ccc5c commit 56bd75e
Showing 19 changed files with 202 additions and 180 deletions.
25 changes: 12 additions & 13 deletions README.md
@@ -22,10 +22,8 @@ using wide and deep models.

### Introduction

`pytorch-widedeep` is based on Google's Wide and Deep Algorithm. Details of
the original algorithm can be found
[here](https://www.tensorflow.org/tutorials/wide_and_deep), and the research
paper can be found [here](https://arxiv.org/abs/1606.07792).
`pytorch-widedeep` is based on Google's Wide and Deep Algorithm, [Wide & Deep
Learning for Recommender Systems](https://arxiv.org/abs/1606.07792).

In general terms, `pytorch-widedeep` is a package to use deep learning with
tabular data. In particular, it is intended to facilitate the combination of text
@@ -40,18 +38,20 @@ in the Figure below:
</p>

The dashed boxes in the figure represent optional, overall components, and the
dashed lines indicate the corresponding connections, depending on whether or
not certain components are present. For example, the dashed, blue-lines
indicate that the ``deeptabular``, ``deeptext`` and ``deepimage`` components
are connected directly to the output neuron or neurons (depending on whether
we are performing a binary classification or regression, or a multi-class
classification) if the optional ``deephead`` is not present. The components
within the faded-pink rectangle are concatenated.
dashed lines/arrows indicate the corresponding connections, depending on
whether or not certain components are present. For example, the dashed
blue lines indicate that the ``deeptabular``, ``deeptext`` and ``deepimage``
components are connected directly to the output neuron or neurons (depending
on whether we are performing a binary classification or regression, or a
multi-class classification) if the optional ``deephead`` is not present.
Finally, the components within the faded-pink rectangle are concatenated.
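
As a quick orientation, here is a minimal sketch of how those components are assembled (not part of the diff; the toy column layout is an assumption, the class names match the hunks further down in this commit):

```python
# Minimal sketch: assemble Wide & Deep components and let ``WideDeep``
# wire them to the output neuron(s). Toy column setup assumed.
from pytorch_widedeep.models import Wide, TabMlp, WideDeep

wide = Wide(wide_dim=10, pred_dim=1)            # size of the wide input
deeptabular = TabMlp(
    column_idx={"age": 0, "education": 1},      # column name -> position
    embed_input=[("education", 11, 32)],        # (col, n_unique, embed_dim)
    continuous_cols=["age"],
    mlp_hidden_dims=[64, 32],
)
model = WideDeep(wide=wide, deeptabular=deeptabular)  # ``deephead`` is optional
```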

Note that it is not possible to illustrate the number of possible
architectures and components available in ``pytorch-widedeep`` in one Figure.
Therefore, for more details on possible architectures (and more) please see
the documentation, or the Examples folders and the notebooks there.
the
[documentation](https://pytorch-widedeep.readthedocs.io/en/latest/index.html),
or the Examples folders and the notebooks there.

In math terms, and following the notation in the
[paper](https://arxiv.org/abs/1606.07792), the expression for the architecture
@@ -187,7 +187,6 @@ from pytorch_widedeep.metrics import Accuracy
# the following 4 lines are not directly related to ``pytorch-widedeep``. I
# assume you have downloaded the dataset and placed it in a dir called
# data/adult/

df = pd.read_csv("data/adult/adult.csv.zip")
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
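
The quick-start is truncated here by the viewer. For orientation, a sketch of how it typically continues in this version of the library; `TabPreprocessor`, `column_idx` and `embeddings_input` are names from the period's API, but treat the exact columns and hyperparameters as illustrative assumptions:

```python
# Hypothetical continuation of the truncated quick-start (illustrative only).
import numpy as np

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy

wide_cols = ["education", "relationship", "workclass", "occupation", "native-country", "gender"]
embed_cols = [("education", 16), ("workclass", 16), ("occupation", 16)]
continuous_cols = ["age", "hours-per-week"]
target = df["income_label"].values

# preprocess
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols)
X_wide = wide_preprocessor.fit_transform(df)
tab_preprocessor = TabPreprocessor(embed_cols=embed_cols, continuous_cols=continuous_cols)
X_tab = tab_preprocessor.fit_transform(df)

# build the model
wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)
deeptabular = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deeptabular=deeptabular)

# train
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=5, batch_size=256)
```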
4 changes: 3 additions & 1 deletion docs/index.rst
@@ -53,7 +53,9 @@ within the faded-pink rectangle are concatenated.
Note that it is not possible to illustrate the number of possible
architectures and components available in ``pytorch-widedeep`` in one Figure.
Therefore, for more details on possible architectures (and more) please read
this documentation, or see the Examples folders in the repo.
this documentation, or see the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_ folders
in the repo.

In math terms, and following the notation in the `paper
<https://arxiv.org/abs/1606.07792>`_, the expression for the architecture
6 changes: 3 additions & 3 deletions docs/model_components.rst
@@ -1,9 +1,9 @@
The ``models`` module
======================

This module contains the four main components that will comprise Wide and Deep
model, and the ``WideDeep`` "constructor" class. These four components are:
``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.
This module contains the four main components that will comprise a Wide and
Deep model, and the ``WideDeep`` "constructor" class. These four components
are: ``wide``, ``deeptabular``, ``deeptext``, ``deepimage``.

.. note:: ``TabMlp``, ``TabResnet`` and ``TabTransformer`` can all be used
as the ``deeptabular`` component of the model and simply represent
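
The note is truncated in this view. A minimal sketch of the interchangeability it describes (toy column setup assumed; constructor arguments match the hunks further down in this commit):

```python
# Any of the three tabular models can act as the ``deeptabular`` component.
from pytorch_widedeep.models import TabMlp, TabResnet, WideDeep

common = dict(
    column_idx={"education": 0, "age": 1},   # toy column layout
    embed_input=[("education", 11, 32)],     # (col, n_unique, embed_dim)
    continuous_cols=["age"],
)
model_a = WideDeep(deeptabular=TabMlp(mlp_hidden_dims=[64, 32], **common))
model_b = WideDeep(deeptabular=TabResnet(blocks_dims=[64, 64, 64], **common))
```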
2 changes: 1 addition & 1 deletion docs/trainer.rst
@@ -3,7 +3,7 @@ Training wide and deep models for tabular data

`...` or just deep learning models for tabular data.

Here is the documentation for the ``Trainer`` class, that will do all the heavy lifting
Here is the documentation for the ``Trainer`` class, which will do all the heavy lifting.

Trainer is also available from ``pytorch-widedeep`` directly, for example, one could do:

7 changes: 4 additions & 3 deletions docs/utils/fastai_transforms.rst
@@ -5,9 +5,10 @@ I have directly copied and pasted part of
the ``fastai`` library. The reason for doing so is that
``pytorch_widedeep`` only needs the ``Tokenizer`` and the ``Vocab`` classes
there. This way I avoid extra dependencies. Credit for all the code in the
``fastai_transforms`` module to Jeremy Howard and the `fastai` team. I only
include the documentation here for completion, but I strongly advise the user
to read the ``fastai`` `documentation <https://docs.fast.ai/>`_.
``fastai_transforms`` module in this ``pytorch-widedeep`` package goes to
Jeremy Howard and the `fastai` team. I only include the documentation here for
completeness, but I strongly advise the user to read the ``fastai``
`documentation <https://docs.fast.ai/>`_.
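
As a quick orientation, a hedged sketch of how the two classes are typically used together (method names follow the fastai v1 API this module copies; treat them as assumptions):

```python
# Sketch: tokenise raw text, build a vocabulary, numericalise (fastai-style).
# The fastai tokenizer relies on spacy under the hood.
from pytorch_widedeep.utils.fastai_transforms import Tokenizer, Vocab

texts = ["the cat sat on the mat", "the dog barked at the cat"]
tok = Tokenizer()
tokens = tok.process_all(texts)                       # list of token lists
vocab = Vocab.create(tokens, max_vocab=10000, min_freq=1)
ids = vocab.numericalize(tokens[0])                   # tokens -> int indices
print(vocab.textify(ids))                             # and back again
```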

.. autoclass:: pytorch_widedeep.utils.fastai_transforms.Tokenizer
:members:
2 changes: 1 addition & 1 deletion examples/03_Binary_Classification_with_Defaults.ipynb
@@ -859,7 +859,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The only requisite is that the model component must be passed to `WideDeep` before \"fed\" to the `Trainer`. This is because the `Trainer` is coded so that it trains a model that has a parent called `model` and then children that correspond to the model components: `wide`, `deeptabular`, `deeptext` and `deepimage`. "
"The only requisite is that the model component must be passed to `WideDeep` before \"fed\" to the `Trainer`. This is because the `Trainer` is coded so that it trains a model that has a parent called `model` and then children that correspond to the model components: `wide`, `deeptabular`, `deeptext` and `deepimage`. Also, `WideDeep` builds the last connection between the output of those components and the final, output neuron(s)."
]
}
],
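
A hedged sketch of what that notebook cell describes (hypothetical toy setup; not part of the diff):

```python
# Even a single component must be wrapped in WideDeep before the Trainer
# sees it; WideDeep also adds the final output connection.
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import TabMlp, WideDeep

tab_mlp = TabMlp(
    column_idx={"age": 0, "hours-per-week": 1},
    continuous_cols=["age", "hours-per-week"],
    mlp_hidden_dims=[32, 16],
)
model = WideDeep(deeptabular=tab_mlp)  # parent `model` with child `deeptabular`
trainer = Trainer(model, objective="binary")
```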
4 changes: 2 additions & 2 deletions examples/06_FineTune_and_WarmUp_Model_Components.ipynb
@@ -331,7 +331,7 @@
"metadata": {},
"outputs": [],
"source": [
"trainer2 = Trainer(model=\"models_dir/model1.t\", objective=\"binary\", metrics=[Accuracy])"
"trainer2 = Trainer(model_path=\"models_dir/model1.t\", objective=\"binary\", metrics=[Accuracy])"
]
},
{
@@ -949,7 +949,7 @@
"metadata": {},
"outputs": [],
"source": [
"trainer7 = Trainer(model=\"models_dir/model3.t\", objective=\"binary\", metrics=[Accuracy])"
"trainer7 = Trainer(model_path=\"models_dir/model3.t\", objective=\"binary\", metrics=[Accuracy])"
]
},
{
13 changes: 2 additions & 11 deletions pypi_README.md
@@ -17,10 +17,8 @@ using wide and deep models.

### Introduction

`pytorch-widedeep` is based on Google's Wide and Deep Algorithm. Details of
the original algorithm can be found
[here](https://www.tensorflow.org/tutorials/wide_and_deep), and the research
paper can be found [here](https://arxiv.org/abs/1606.07792).
`pytorch-widedeep` is based on Google's Wide and Deep Algorithm, [Wide & Deep
Learning for Recommender Systems](https://arxiv.org/abs/1606.07792).

In general terms, `pytorch-widedeep` is a package to use deep learning with
tabular data. In particular, it is intended to facilitate the combination of text
@@ -84,12 +82,6 @@ Binary classification with the [adult
dataset](https://www.kaggle.com/wenruliu/adult-income-dataset)
using `Wide` and `DeepDense` and default settings.

### Quick start

Binary classification with the [adult
dataset]([adult](https://www.kaggle.com/wenruliu/adult-income-dataset))
using `Wide` and `DeepDense` and defaults settings.


```python
```
@@ -110,7 +102,6 @@ from pytorch_widedeep.metrics import Accuracy
# the following 4 lines are not directly related to ``pytorch-widedeep``. I
# assume you have downloaded the dataset and placed it in a dir called
# data/adult/

df = pd.read_csv("data/adult/adult.csv.zip")
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
22 changes: 11 additions & 11 deletions pytorch_widedeep/models/deep_image.py
@@ -32,11 +32,11 @@ def __init__(
resnet_architecture: int = 18,
freeze_n: int = 6,
head_hidden_dims: Optional[List[int]] = None,
head_activation: Optional[str] = "relu",
head_dropout: Optional[float] = None,
head_batchnorm: Optional[bool] = False,
head_batchnorm_last: Optional[bool] = False,
head_linear_first: Optional[bool] = False,
head_activation: str = "relu",
head_dropout: float = 0.1,
head_batchnorm: bool = False,
head_batchnorm_last: bool = False,
head_linear_first: bool = False,
):
r"""
Standard image classifier/regressor using a pretrained network (in
@@ -69,19 +69,19 @@ def __init__(
freeze_n: int, default = 6
number of layers to freeze. Must be less than or equal to 8. If 8
the entire 'backbone' of the network will be frozen
head_hidden_dims: List, Optional
head_hidden_dims: List, Optional, default = None
List with the number of neurons per dense layer in the head. e.g: [64,32]
head_activation: str, Optional, default = "relu"
head_activation: str, default = "relu"
Activation function for the dense layers in the head.
head_dropout: float, Optional, default = 0.
head_dropout: float, default = 0.1
float indicating the dropout between the dense layers.
head_batchnorm: bool, Optional, default = False
head_batchnorm: bool, default = False
Boolean indicating whether or not batch normalization will be applied
to the dense layers
head_batchnorm_last: bool, Optional, default = False
head_batchnorm_last: bool, default = False
Boolean indicating whether or not batch normalization will be applied
to the last of the dense layers
head_linear_first: bool, Optional, default = False
head_linear_first: bool, default = False
Boolean indicating the order of the operations in the dense
layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
LIN -> ACT]``
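
A hedged sketch of a ``DeepImage`` with a custom head; ``pretrained`` is an assumed argument (not visible in this hunk), the rest are documented above:

```python
# Sketch: pretrained resnet18 backbone with a partially frozen body and a
# small dense head (argument values are illustrative).
from pytorch_widedeep.models import DeepImage

deepimage = DeepImage(
    pretrained=True,
    resnet_architecture=18,     # resnet18 backbone
    freeze_n=6,                 # freeze the first 6 layers of the backbone
    head_hidden_dims=[64, 32],  # optional dense head
    head_dropout=0.1,
)
```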
24 changes: 12 additions & 12 deletions pytorch_widedeep/models/deep_text.py
@@ -19,13 +19,13 @@ def __init__(
padding_idx: int = 1,
embed_dim: Optional[int] = None,
embed_matrix: Optional[np.ndarray] = None,
embed_trainable: Optional[bool] = True,
embed_trainable: bool = True,
head_hidden_dims: Optional[List[int]] = None,
head_activation: Optional[str] = "relu",
head_activation: str = "relu",
head_dropout: Optional[float] = None,
head_batchnorm: Optional[bool] = False,
head_batchnorm_last: Optional[bool] = False,
head_linear_first: Optional[bool] = False,
head_batchnorm: bool = False,
head_batchnorm_last: bool = False,
head_linear_first: bool = False,
):
r"""Standard text classifier/regressor comprised by a stack of RNNs
(in particular LSTMs).
@@ -47,30 +47,30 @@
bidirectional: bool, default = True
indicates whether the staked RNNs are bidirectional
padding_idx: int, default = 1
index of the padding token in the padded-tokenised sequences. default:
1. I use the ``fastai`` tokenizer where the token index 0 is reserved
index of the padding token in the padded-tokenised sequences. I
use the ``fastai`` tokenizer where the token index 0 is reserved
for the `'unknown'` word token
embed_dim: int, Optional, default = None
Dimension of the word embedding matrix if non-pretained word
vectors are used
embed_matrix: np.ndarray, Optional, default = None
Pretrained word embeddings
embed_trainable: bool, Optional, default = None
embed_trainable: bool, default = True
Boolean indicating if the pretrained embeddings are trainable
head_hidden_dims: List, Optional, default = None
List with the sizes of the stacked dense layers in the head
e.g: [128, 64]
head_activation: str, Optional, default = "relu"
head_activation: str, default = "relu"
Activation function for the dense layers in the head
head_dropout: float, Optional, default = None
dropout between the dense layers in the head
head_batchnorm: bool, Optional, default = False
head_batchnorm: bool, default = False
Whether or not to include batch normalization in the dense layers that
form the `'texthead'`
head_batchnorm_last: bool, Optional, default = False
head_batchnorm_last: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
last of the dense layers in the head
head_linear_first: bool, Optional, default = False
head_linear_first: bool, default = False
Boolean indicating whether the order of the operations in the dense
layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
LIN -> ACT]``
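
A hedged sketch of a ``DeepText`` with pretrained embeddings; ``vocab_size`` is an assumed argument (not visible in this hunk), the rest are documented above:

```python
# Sketch: stacked-LSTM text model fed with a frozen pretrained embedding
# matrix (random matrix stands in for real pretrained vectors).
import numpy as np

from pytorch_widedeep.models import DeepText

embed_matrix = np.random.rand(4000, 100).astype("float32")  # (vocab, embed_dim)
deeptext = DeepText(
    vocab_size=4000,
    embed_matrix=embed_matrix,
    embed_trainable=False,       # keep the pretrained vectors frozen
    padding_idx=1,
    head_hidden_dims=[64, 32],
)
```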
10 changes: 5 additions & 5 deletions pytorch_widedeep/models/tab_mlp.py
@@ -75,14 +75,14 @@ def __init__(
column_idx: Dict[str, int],
mlp_hidden_dims: List[int] = [200, 100],
mlp_activation: str = "relu",
mlp_dropout: Optional[Union[float, List[float]]] = 0.1,
mlp_dropout: Union[float, List[float]] = 0.1,
mlp_batchnorm: bool = False,
mlp_batchnorm_last: bool = False,
mlp_linear_first: bool = False,
embed_input: Optional[List[Tuple[str, int, int]]] = None,
embed_dropout: float = 0.1,
continuous_cols: Optional[List[str]] = None,
batchnorm_cont: Optional[bool] = False,
batchnorm_cont: bool = False,
):
r"""Defines a ``TabMlp`` model that can be used as the ``deeptabular``
component of a Wide & Deep model.
@@ -102,7 +102,7 @@ def __init__(
mlp_activation: str, default = "relu"
Activation function for the dense layers of the MLP. Currently
only "relu", "leaky_relu" and "gelu" are supported
mlp_dropout: float or List, Optional, default = 0.1
mlp_dropout: float or List, default = 0.1
float or List of floats with the dropout between the dense layers.
e.g: [0.5,0.5]
mlp_batchnorm: bool, default = False
@@ -115,12 +115,12 @@ def __init__(
Boolean indicating the order of the operations in the dense
layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
LIN -> ACT]``
embed_input: List, Optional
embed_input: List, Optional, default = None
List of Tuples with the column name, number of unique values and
embedding dimension. e.g. [(education, 11, 32), ...]
embed_dropout: float, default = 0.1
embeddings dropout
continuous_cols: List, Optional
continuous_cols: List, Optional, default = None
List with the name of the numeric (aka continuous) columns
batchnorm_cont: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
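
A sketch of a ``TabMlp`` using the documented ``embed_input`` format, where ("education", 11, 32) means 11 unique values embedded in 32 dimensions (column names and indices here are illustrative assumptions):

```python
# Sketch: categorical embeddings + one continuous column through a 2-layer MLP.
from pytorch_widedeep.models import TabMlp

tab_mlp = TabMlp(
    column_idx={"education": 0, "workclass": 1, "age": 2},
    embed_input=[("education", 11, 32), ("workclass", 9, 16)],
    continuous_cols=["age"],
    mlp_hidden_dims=[200, 100],
    mlp_dropout=[0.1, 0.1],   # one dropout value per dense layer
    batchnorm_cont=True,      # batch normalization on the continuous input
)
```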
42 changes: 21 additions & 21 deletions pytorch_widedeep/models/tab_resnet.py
@@ -90,15 +90,15 @@ def __init__(
blocks_dims: List[int] = [200, 100, 100],
blocks_dropout: float = 0.1,
mlp_hidden_dims: Optional[List[int]] = None,
mlp_activation: Optional[str] = "relu",
mlp_dropout: Optional[float] = 0.1,
mlp_batchnorm: Optional[bool] = False,
mlp_batchnorm_last: Optional[bool] = False,
mlp_linear_first: Optional[bool] = False,
embed_dropout: Optional[float] = 0.1,
mlp_activation: str = "relu",
mlp_dropout: float = 0.1,
mlp_batchnorm: bool = False,
mlp_batchnorm_last: bool = False,
mlp_linear_first: bool = False,
embed_dropout: float = 0.1,
continuous_cols: Optional[List[str]] = None,
batchnorm_cont: Optional[bool] = False,
concat_cont_first: Optional[bool] = True,
batchnorm_cont: bool = False,
concat_cont_first: bool = True,
):
r"""Defines a so-called ``TabResnet`` model that can be used as the
``deeptabular`` component of a Wide & Deep model.
@@ -136,28 +136,28 @@ def __init__(
[64, 32]. If ``None`` the output of the Resnet Blocks will be
connected directly to the output neuron(s), i.e. using a MLP is
optional.
mlp_activation: str, Optional, default = "relu"
mlp_activation: str, default = "relu"
Activation function for the dense layers of the MLP
mlp_dropout: float, Optional, default = 0.1
mlp_dropout: float, default = 0.1
float with the dropout between the dense layers of the MLP.
mlp_batchnorm: bool, Optional, default = False
mlp_batchnorm: bool, default = False
Boolean indicating whether or not batch normalization will be applied
to the dense layers
mlp_batchnorm_last: bool, Optional, default = False
mlp_batchnorm_last: bool, default = False
Boolean indicating whether or not batch normalization will be applied
to the last of the dense layers
mlp_linear_first: bool, Optional, default = False
mlp_linear_first: bool, default = False
Boolean indicating the order of the operations in the dense
layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
LIN -> ACT]``
embed_dropout: float, Optional, default = 0.1
embed_dropout: float, default = 0.1
embeddings dropout
continuous_cols: List, Optional, default = None
List with the name of the numeric (aka continuous) columns
batchnorm_cont: bool, default = False
Boolean indicating whether or not to apply batch normalization to the
continuous input
concat_cont_first: bool, Optional, default = True
concat_cont_first: bool, default = True
Boolean indicating whether the continuous columns will be
concatenated with the Embeddings and then passed through the
Resnet blocks (``True``) or, alternatively, will be concatenated
@@ -175,13 +175,13 @@ def __init__(
if ``mlp_hidden_dims`` is not ``None``, this attribute will be an MLP
model that will receive:
i) the results of the concatenation of the embeddings and the
continuous columns -- if present -- and then passed it through
the ``dense_resnet`` (``concat_cont_first = True``), or
- the results of the concatenation of the embeddings and the
continuous columns -- if present -- which are then passed through
the ``dense_resnet`` (``concat_cont_first = True``), or
ii) the result of passing the embeddings through the ``dense_resnet``
and then concatenating the results with the continuous columns --
if present -- (``concat_cont_first = False``)
- the result of passing the embeddings through the ``dense_resnet``
and then concatenating the results with the continuous columns --
if present -- (``concat_cont_first = False``)
output_dim: `int`
The output dimension of the model. This is a required attribute
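
A sketch contrasting the two ``concat_cont_first`` behaviours documented above (toy columns and dimensions; treat the exact arguments as illustrative):

```python
# Sketch: the continuous columns either join the embeddings *before* the
# Resnet blocks, or are concatenated *after* them, before the optional MLP.
from pytorch_widedeep.models import TabResnet

common = dict(
    column_idx={"education": 0, "age": 1},
    embed_input=[("education", 11, 32)],
    continuous_cols=["age"],
    blocks_dims=[200, 100, 100],
)
resnet_a = TabResnet(concat_cont_first=True, **common)   # before the blocks
resnet_b = TabResnet(concat_cont_first=False, mlp_hidden_dims=[64, 32], **common)
```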
