Update doc
pete-machine committed Jul 12, 2023
1 parent 86660f9 commit 0109bb6
Showing 7 changed files with 320 additions and 192 deletions.
134 changes: 103 additions & 31 deletions README.ipynb
@@ -1,20 +1,5 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# TorchBricks\n",
"\n",
"[![codecov](https://codecov.io/gh/PeteHeine/torchbricks/branch/main/graph/badge.svg?token=torchbricks_token_here)](https://codecov.io/gh/PeteHeine/torchbricks)\n",
"[![CI](https://github.com/PeteHeine/torchbricks/actions/workflows/main.yml/badge.svg)](https://github.com/PeteHeine/torchbricks/actions/workflows/main.yml)\n",
"\n",
"TorchBricks builds PyTorch models using small reusable and decoupled parts - we call them bricks.\n",
"\n",
"The concept is simple and flexible and allows you to more easily combine, add more or swap out parts of the model (preprocessor, backbone, neck, head or post-processor), change the task or extend it with multiple tasks.\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
@@ -123,7 +108,8 @@
"outputs": [],
"source": [
"brick_collection = BrickCollection(bricks)\n",
"batch_images = torch.rand((2, 3, 100, 200))\n",
"batch_size=2\n",
"batch_images = torch.rand((batch_size, 3, 100, 200))\n",
"named_outputs = brick_collection(named_inputs={'raw': batch_images}, stage=Stage.INFERENCE)\n",
"print(named_outputs.keys())"
]
@@ -172,14 +158,17 @@
"metadata": {},
"outputs": [],
"source": [
"num_classes = 3\n",
"bricks = {\n",
" 'preprocessor': BrickNotTrainable(PreprocessorDummy(), input_names=['raw'], output_names=['processed'], alive_stages=\"all\"),\n",
" 'backbone': BrickTrainable(TinyModel(n_channels=3, n_features=10), input_names=['processed'], output_names=['embedding'], alive_stages=\"all\"),\n",
" 'head': BrickTrainable(ClassifierDummy(num_classes=3, in_features=10), input_names=['embedding'], output_names=['logits', 'softmaxed'], \n",
" alive_stages=\"all\"),\n",
" 'backbone': BrickTrainable(TinyModel(n_channels=3, n_features=10), input_names=['processed'], output_names=['embedding'], \n",
" alive_stages=\"all\"),\n",
" 'head': BrickTrainable(ClassifierDummy(num_classes=num_classes, in_features=10), input_names=['embedding'], \n",
" output_names=['logits', 'softmaxed'], alive_stages=\"all\"),\n",
" 'loss': BrickLoss(model=nn.CrossEntropyLoss(), input_names=['logits', 'targets'], output_names=['loss_ce'], \n",
" alive_stages=[Stage.TRAIN, Stage.VALIDATION, Stage.TEST], loss_output_names=\"all\")\n",
"}"
"}\n",
"brick_collection = BrickCollection(bricks)"
]
},
{
@@ -195,15 +184,34 @@
"\n",
"Another advantage is that the model has different input requirements for different stages.\n",
"\n",
"For `Stage.INFERENCE` and `Stage.EXPORT` stages, the model only requires the `raw` tensor as input. \n",
"\n",
"While for `Stage.TRAIN`, `Stage.VALIDATION` and `Stage.TEST` stages, the model requires both `raw` and `targets` input tensors.\n",
"For `Stage.INFERENCE` and `Stage.EXPORT` stages, the model only requires the `raw` tensor as input. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"named_outputs_without_loss = brick_collection(named_inputs={'raw': batch_images}, stage=Stage.INFERENCE) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"```python\n",
" named_outputs_without_loss = brick_collection(named_inputs={'raw': batch_images}, stage=Stage.INFERENCE)\n",
" named_outputs_with_loss = brick_collection(named_inputs={'raw': batch_images, \"targets\": torch.ones((1,3))}, stage=Stage.TRAIN)\n",
"```"
"For `Stage.TRAIN`, `Stage.VALIDATION` and `Stage.TEST` stages, the model requires both `raw` and `targets` input tensors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"named_outputs_with_loss = brick_collection(named_inputs={'raw': batch_images, \"targets\": torch.ones((batch_size,3))}, stage=Stage.TRAIN)"
]
},
{
@@ -243,7 +251,7 @@
"}\n",
"\n",
"brick_collection = BrickCollection(bricks)\n",
"named_inputs = {\"raw\": batch_images, \"targets\": torch.ones((2), dtype=torch.int64)}\n",
"named_inputs = {\"raw\": batch_images, \"targets\": torch.ones((batch_size), dtype=torch.int64)}\n",
"named_outputs = brick_collection(named_inputs=named_inputs, stage=Stage.TRAIN)\n",
"named_outputs = brick_collection(named_inputs=named_inputs, stage=Stage.TRAIN)\n",
"named_outputs = brick_collection(named_inputs=named_inputs, stage=Stage.TRAIN)\n",
@@ -305,9 +313,9 @@
"\n",
"Missing sections:\n",
"\n",
"- [ ] Acts as a nn.Module\n",
"- [x] Export as ONNX\n",
"- [x] Acts as a nn.Module\n",
"- [ ] Acts as a dictionary - Nested brick collection\n",
"- [ ] Export as ONNX\n",
"- [ ] Training with Pytorch lightning\n",
"- [ ] Pass all inputs as a dictionary `input_names='all'`\n",
"- [ ] Using stage inside module\n",
@@ -320,6 +328,70 @@
"`BrickMetricSingle` and `BrickCollection`) to both ensure sensible defaults and to show the intent of each brick. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Brick features: Export as ONNX\n",
"To export a brick collection as ONNX, we provide the `export_bricks_as_onnx` function. \n",
"\n",
"Pass an example input (`named_inputs`) to trace the brick collection.\n",
"Set `dynamic_batch_size=True` to support inputs of any batch size. Here we explicitly set `stage=Stage.EXPORT`, which is also \n",
"the default."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from torchbricks.brick_utils import export_bricks_as_onnx\n",
"path_onnx = Path(\"build/readme_model.onnx\")\n",
"export_bricks_as_onnx(path_onnx=path_onnx, \n",
" brick_collection=brick_collection, \n",
" named_inputs=named_inputs, \n",
" dynamic_batch_size=True, \n",
" stage=Stage.EXPORT)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Brick features: Act as a nn.Module\n",
"A brick collection acts as an `nn.Module`, which means we can do the following: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Move to a specific device (CPU/GPU) or precision; model parameters are moved or cast automatically\n",
"brick_collection_half = brick_collection.to(torch.float16)\n",
"\n",
"\n",
"# Save model parameters\n",
"path_model = Path(\"build/readme_model.pt\")\n",
"torch.save(brick_collection_half.state_dict(), path_model)\n",
"\n",
"# Load model parameters\n",
"brick_collection_half.load_state_dict(torch.load(path_model))\n",
"\n",
"# Access parameters\n",
"brick_collection_half.named_parameters()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
@@ -448,7 +520,7 @@
"- [x] Add typeguard\n",
"- [x] Allow a brick to receive all named_inputs and add a test for it.\n",
"- [x] Fix the release process. It should be as simple as running `make release`.\n",
"- [ ] Add onnx export example to the README.md\n",
"- [x] Add onnx export example to the README.md\n",
"- [ ] Add DAG-like functionality to check that inputs and outputs work for all model stages.\n",
"- [ ] Use mypy, pyright or pyre to do static code checks. \n",
"- [ ] Decide: Add stage as an internal state and not in the forward pass:\n",
85 changes: 62 additions & 23 deletions README.md
@@ -17,16 +17,6 @@ jupyter:
-->

# TorchBricks

[![codecov](https://codecov.io/gh/PeteHeine/torchbricks/branch/main/graph/badge.svg?token=torchbricks_token_here)](https://codecov.io/gh/PeteHeine/torchbricks)
[![CI](https://github.com/PeteHeine/torchbricks/actions/workflows/main.yml/badge.svg)](https://github.com/PeteHeine/torchbricks/actions/workflows/main.yml)

TorchBricks builds PyTorch models using small reusable and decoupled parts - we call them bricks.

The concept is simple and flexible and allows you to more easily combine, add more or swap out parts of the model (preprocessor, backbone, neck, head or post-processor), change the task or extend it with multiple tasks.


<!-- #region -->

## Install it with pip
@@ -98,7 +88,8 @@ returns a dictionary with both intermediated and resulting tensors.

```python
brick_collection = BrickCollection(bricks)
batch_images = torch.rand((2, 3, 100, 200))
batch_size=2
batch_images = torch.rand((batch_size, 3, 100, 200))
named_outputs = brick_collection(named_inputs={'raw': batch_images}, stage=Stage.INFERENCE)
print(named_outputs.keys())
```
@@ -131,17 +122,19 @@ In above example this is not particular interesting - because preprocessor, back
So we will demonstrate by adding a loss brick (`BrickLoss`) and specifying `alive_stages` for each brick.

```python
num_classes = 3
bricks = {
'preprocessor': BrickNotTrainable(PreprocessorDummy(), input_names=['raw'], output_names=['processed'], alive_stages="all"),
'backbone': BrickTrainable(TinyModel(n_channels=3, n_features=10), input_names=['processed'], output_names=['embedding'], alive_stages="all"),
'head': BrickTrainable(ClassifierDummy(num_classes=3, in_features=10), input_names=['embedding'], output_names=['logits', 'softmaxed'],
alive_stages="all"),
'backbone': BrickTrainable(TinyModel(n_channels=3, n_features=10), input_names=['processed'], output_names=['embedding'],
alive_stages="all"),
'head': BrickTrainable(ClassifierDummy(num_classes=num_classes, in_features=10), input_names=['embedding'],
output_names=['logits', 'softmaxed'], alive_stages="all"),
'loss': BrickLoss(model=nn.CrossEntropyLoss(), input_names=['logits', 'targets'], output_names=['loss_ce'],
alive_stages=[Stage.TRAIN, Stage.VALIDATION, Stage.TEST], loss_output_names="all")
}
brick_collection = BrickCollection(bricks)
```

<!-- #region -->
We set `preprocessor`, `backbone` and `head` to be alive on all stages `alive_stages="all"` - this is the default behavior and
similar to before.

@@ -152,14 +145,19 @@ Another advantages is that model have different input requirements for different

For `Stage.INFERENCE` and `Stage.EXPORT` stages, the model only requires the `raw` tensor as input.

While for `Stage.TRAIN`, `Stage.VALIDATION` and `Stage.TEST` stages, the model requires both `raw` and `targets` input tensors.
```python
named_outputs_without_loss = brick_collection(named_inputs={'raw': batch_images}, stage=Stage.INFERENCE)
```

<!-- #region -->


For `Stage.TRAIN`, `Stage.VALIDATION` and `Stage.TEST` stages, the model requires both `raw` and `targets` input tensors.
<!-- #endregion -->

```python
named_outputs_without_loss = brick_collection(named_inputs={'raw': batch_images}, stage=Stage.INFERENCE)
named_outputs_with_loss = brick_collection(named_inputs={'raw': batch_images, "targets": torch.ones((1,3))}, stage=Stage.TRAIN)
named_outputs_with_loss = brick_collection(named_inputs={'raw': batch_images, "targets": torch.ones((batch_size,3))}, stage=Stage.TRAIN)
```
<!-- #endregion -->
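
To make the stage-dependent input requirements more concrete, here is a minimal pure-Python sketch of the idea. It does not use TorchBricks itself: the `Stage` enum below is a hypothetical stand-in (its string values are assumptions), and `BRICKS` and `required_inputs` are illustrative names that only mirror the `input_names`, `output_names` and `alive_stages` from the example above.

```python
from enum import Enum
from typing import Dict, List, Set, Tuple


class Stage(Enum):  # hypothetical stand-in for the torchbricks Stage enum
    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"
    INFERENCE = "inference"
    EXPORT = "export"


ALL_STAGES: Set[Stage] = set(Stage)

# brick name -> (input_names, output_names, alive_stages), mirroring the example above
BRICKS: Dict[str, Tuple[List[str], List[str], Set[Stage]]] = {
    "preprocessor": (["raw"], ["processed"], ALL_STAGES),
    "backbone": (["processed"], ["embedding"], ALL_STAGES),
    "head": (["embedding"], ["logits", "softmaxed"], ALL_STAGES),
    "loss": (["logits", "targets"], ["loss_ce"], {Stage.TRAIN, Stage.VALIDATION, Stage.TEST}),
}


def required_inputs(bricks: Dict[str, Tuple[List[str], List[str], Set[Stage]]], stage: Stage) -> List[str]:
    """Return the names that must be passed in from outside for the bricks alive in `stage`."""
    produced: Set[str] = set()
    required: List[str] = []
    for input_names, output_names, alive_stages in bricks.values():
        if stage not in alive_stages:
            continue  # a brick that is not alive in this stage is skipped
        for name in input_names:
            # An input is external if no earlier alive brick has produced it
            if name not in produced and name not in required:
                required.append(name)
        produced.update(output_names)
    return required


print(required_inputs(BRICKS, Stage.INFERENCE))  # only 'raw' is needed
print(required_inputs(BRICKS, Stage.TRAIN))      # 'raw' and 'targets' are needed
```

Because the loss brick is skipped for `Stage.INFERENCE` and `Stage.EXPORT`, `targets` is only required when the loss brick is alive.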

## Bricks for model training
We are not creating a training framework, but to easily use the brick collection in your favorite training framework or custom
@@ -187,7 +185,7 @@ bricks = {
}

brick_collection = BrickCollection(bricks)
named_inputs = {"raw": batch_images, "targets": torch.ones((2), dtype=torch.int64)}
named_inputs = {"raw": batch_images, "targets": torch.ones((batch_size), dtype=torch.int64)}
named_outputs = brick_collection(named_inputs=named_inputs, stage=Stage.TRAIN)
named_outputs = brick_collection(named_inputs=named_inputs, stage=Stage.TRAIN)
named_outputs = brick_collection(named_inputs=named_inputs, stage=Stage.TRAIN)
@@ -234,9 +232,9 @@ Including metrics and losses with the model.

Missing sections:

- [ ] Acts as a nn.Module
- [x] Export as ONNX
- [x] Acts as a nn.Module
- [ ] Acts as a dictionary - Nested brick collection
- [ ] Export as ONNX
- [ ] Training with Pytorch lightning
- [ ] Pass all inputs as a dictionary `input_names='all'`
- [ ] Using stage inside module
@@ -249,6 +247,47 @@ but instead we recommend using our pre-configured brick modules (`BrickLoss`, `B
`BrickMetricSingle` and `BrickCollection`) to both ensure sensible defaults and to show the intent of each brick.


### Brick features: Export as ONNX
To export a brick collection as ONNX, we provide the `export_bricks_as_onnx` function.

Pass an example input (`named_inputs`) to trace the brick collection.
Set `dynamic_batch_size=True` to support inputs of any batch size. Here we explicitly set `stage=Stage.EXPORT`, which is also
the default.

```python
from pathlib import Path
from torchbricks.brick_utils import export_bricks_as_onnx
path_onnx = Path("build/readme_model.onnx")
export_bricks_as_onnx(path_onnx=path_onnx,
brick_collection=brick_collection,
named_inputs=named_inputs,
dynamic_batch_size=True,
stage=Stage.EXPORT)
```
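
As a rough illustration of what `dynamic_batch_size=True` amounts to, the sketch below mirrors the `dynamic_axes` mapping that `export_bricks_as_onnx` builds in `src/torchbricks/brick_utils.py`: axis 0 of every named input and output is marked as a dynamic `batch_size` axis. The helper name `build_dynamic_axes` and the tensor names passed to it are assumptions made for this example only.

```python
from typing import Dict, List


def build_dynamic_axes(input_names: List[str], output_names: List[str]) -> Dict[str, Dict[int, str]]:
    """Mark axis 0 of every named input and output as a dynamic 'batch_size' axis."""
    io_names = input_names + output_names
    return {io_name: {0: "batch_size"} for io_name in io_names}


# Illustrative names only; the real names come from `named_inputs` and the collection outputs
dynamic_axes = build_dynamic_axes(input_names=["raw"], output_names=["logits", "softmaxed"])
print(dynamic_axes)
```

This is also why passing your own `dynamic_axes` through `onnx_export_kwargs` together with `dynamic_batch_size=True` raises a `ValueError`: both would try to define the same mapping.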

### Brick features: Act as a nn.Module
A brick collection acts as an `nn.Module`, which means we can do the following:

```python
# Move to a specific device (CPU/GPU) or precision; model parameters are moved or cast automatically
brick_collection_half = brick_collection.to(torch.float16)


# Save model parameters
path_model = Path("build/readme_model.pt")
torch.save(brick_collection_half.state_dict(), path_model)

# Load model parameters
brick_collection_half.load_state_dict(torch.load(path_model))

# Access parameters
brick_collection_half.named_parameters()
```





### Bag of bricks - reusable bricks modules
Note that the example above also uses bag-of-bricks to import commonly used `nn.Module`s

@@ -346,7 +385,7 @@ MISSING
- [x] Add typeguard
- [x] Allow a brick to receive all named_inputs and add a test for it.
- [x] Fix the release process. It should be as simple as running `make release`.
- [ ] Add onnx export example to the README.md
- [x] Add onnx export example to the README.md
- [ ] Add DAG-like functionality to check that inputs and outputs work for all model stages.
- [ ] Use mypy, pyright or pyre to do static code checks.
- [ ] Decide: Add stage as an internal state and not in the forward pass:
49 changes: 49 additions & 0 deletions src/torchbricks/brick_utils.py
@@ -0,0 +1,49 @@
from pathlib import Path
from typing import Any, Dict

import torch
from torch import nn
from typeguard import typechecked

from torchbricks.bricks import BrickCollection, Stage


@typechecked
class _OnnxExportAdaptor(nn.Module):
    def __init__(self, model: nn.Module, stage: Stage) -> None:
        super().__init__()
        self.model = model
        self.stage = stage

    def forward(self, named_inputs: Dict[str, Any]):
        named_outputs = self.model.forward(named_inputs=named_inputs, stage=self.stage, return_inputs=False)
        return named_outputs


@typechecked
def export_bricks_as_onnx(path_onnx: Path,
                          brick_collection: BrickCollection,
                          named_inputs: Dict[str, torch.Tensor],
                          dynamic_batch_size: bool,
                          stage: Stage = Stage.EXPORT,
                          **onnx_export_kwargs):

    outputs = brick_collection(named_inputs=named_inputs, stage=stage, return_inputs=False)
    onnx_exportable = _OnnxExportAdaptor(model=brick_collection, stage=stage)
    output_names = list(outputs)
    input_names = list(named_inputs)

    if dynamic_batch_size:
        if 'dynamic_axes' in onnx_export_kwargs:
            raise ValueError("Setting both 'dynamic_batch_size==True' and defining 'dynamic_axes' in 'onnx_export_kwargs' is not allowed. ")
        io_names = input_names + output_names
        dynamic_axes = {io_name: {0: 'batch_size'} for io_name in io_names}
        onnx_export_kwargs['dynamic_axes'] = dynamic_axes

    torch.onnx.export(model=onnx_exportable,
                      args=({'named_inputs': named_inputs}, ),
                      f=str(path_onnx),
                      verbose=True,
                      input_names=input_names,
                      output_names=output_names,
                      **onnx_export_kwargs)