Merge pull request #89 from grok-ai/develop
Update nn-template to 0.3.0
lucmos authored Aug 6, 2023
2 parents ef65a25 + 5ba0718 commit ccfbf1d
Showing 12 changed files with 41 additions and 29 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/test_suite.yml
@@ -13,7 +13,7 @@ on:
- synchronize

env:
CACHE_NUMBER: 3 # increase to reset cache manually
CACHE_NUMBER: 4 # increase to reset cache manually
CONDA_ENV_FILE: 'env.yaml'
CONDA_ENV_NAME: 'project-test'
COOKIECUTTER_PROJECT_NAME: 'project-test'
@@ -24,7 +24,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.8', '3.9']
python-version: ['3.8', '3.11']
include:
- os: ubuntu-20.04
label: linux-64
@@ -105,6 +105,14 @@ jobs:
cat ${{ env.CONDA_ENV_FILE }}
working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}

# Install torch cpu-only
- name: Install torch cpu only
shell: bash -l {0}
run: |
sed -i '/nvidia\|cuda/d' ${{ env.CONDA_ENV_FILE }}
cat ${{ env.CONDA_ENV_FILE }}
working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}

- name: Setup Mambaforge
uses: conda-incubator/setup-miniconda@v2
with:
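The new CI step above deletes every CUDA-related dependency so the test environment installs a CPU-only torch. As a rough, hypothetical Python equivalent of that `sed` one-liner (operating on the generated project's `env.yaml`):

```python
from pathlib import Path

# Drop every conda dependency line mentioning nvidia or cuda,
# mirroring: sed -i '/nvidia\|cuda/d' env.yaml
env_file = Path("env.yaml")
kept = [
    line
    for line in env_file.read_text().splitlines()
    if "nvidia" not in line and "cuda" not in line
]
env_file.write_text("\n".join(kept) + "\n")
```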
12 changes: 6 additions & 6 deletions README.md
@@ -1,9 +1,9 @@
# NN Template

<p align="center">
<a href="https://github.com/grok-ai/nn-template/actions/workflows/test_suite.yml"><img alt="CI" src=https://img.shields.io/github/workflow/status/grok-ai/nn-template/Test%20Suite/main?label=main%20checks></a>
<a href="https://github.com/grok-ai/nn-template/actions/workflows/test_suite.yml"><img alt="CI" src=https://img.shields.io/github/workflow/status/grok-ai/nn-template/Test%20Suite/develop?label=develop%20checks></a>
<a href="https://grok-ai.github.io/nn-template"><img alt="Docs" src=https://img.shields.io/github/deployments/grok-ai/nn-template/github-pages?label=docs></a>
<a href="https://github.com/grok-ai/nn-template/actions/workflows/test_suite.yml"><img alt="CI" src=https://github.com/grok-ai/nn-template/actions/workflows/test_suite.yml/badge.svg?branch=main></a>
<a href="https://github.com/grok-ai/nn-template/actions/workflows/test_suite.yml"><img alt="CI" src=https://github.com/grok-ai/nn-template/actions/workflows/test_suite.yml/badge.svg?branch=develop></a>
<a href="https://github.com/grok-ai/nn-template/actions/workflows/publish_docs.yml/badge.svg"><img alt="Docs" src=https://github.com/grok-ai/nn-template/actions/workflows/publish_docs.yml/badge.svg></a>
<a href="https://pypi.org/project/nn-template-core/"><img alt="Release" src="https://img.shields.io/pypi/v/nn-template-core?label=nn-core"></a>
<a href="https://black.readthedocs.io/en/stable/"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
</p>
@@ -34,8 +34,8 @@
Generic template to bootstrap your [PyTorch](https://pytorch.org/get-started/locally/) project,
read more in the [documentation](https://grok-ai.github.io/nn-template).

![nn-template-asciinema](https://s8.gifyu.com/images/optimized.gif)

[![asciicast](https://asciinema.org/a/475623.svg)](https://asciinema.org/a/475623)

## Get started

@@ -47,7 +47,7 @@ cookiecutter https://github.com/grok-ai/nn-template

<details>
<summary>Otherwise</summary>
Cookiecutter manages the setup stages and delivers to you a personalized ready to run project.
Cookiecutter manages the setup stages and delivers to you a personalized ready to run project.

Install it with:
<pre><code>pip install cookiecutter
@@ -56,7 +56,7 @@ Install it with:

More details in the [documentation](https://grok-ai.github.io/nn-template/latest/getting-started/generation/).

## Strengths
## Strengths

- **Actually works for [research](https://grok-ai.github.io/nn-template/latest/papers/)**!
- Guided setup to customize project bootstrapping;
4 changes: 2 additions & 2 deletions cookiecutter.json
@@ -8,6 +8,6 @@
"package_name": "{{ cookiecutter.project_name.strip().lower().replace(' ', '-').replace('-', '_') }}",
"repository_url": "https://github.com/{{ cookiecutter.github_user }}/{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}",
"conda_env_name": "{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}",
"python_version": "3.9",
"__version": "0.2.3"
"python_version": "3.11",
"__version": "0.3.0"
}
2 changes: 1 addition & 1 deletion {{ cookiecutter.repository_name }}/conf/nn/default.yaml
@@ -11,7 +11,7 @@ data:
test:
- _target_: {{ cookiecutter.package_name }}.data.dataset.MyDataset

gpus: ${train.trainer.gpus}
accelerator: ${train.trainer.accelerator}

num_workers:
train: 8
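The old `gpus` interpolation becomes an `accelerator` interpolation. For readers unfamiliar with OmegaConf, a minimal sketch of how a `${train.trainer.accelerator}` reference resolves (a toy config tree, not the template's full configuration):

```python
from omegaconf import OmegaConf

# Toy config tree mirroring the interpolation used above.
cfg = OmegaConf.create(
    {
        "train": {"trainer": {"accelerator": "gpu", "devices": 1}},
        "nn": {"data": {"accelerator": "${train.trainer.accelerator}"}},
    }
)

print(cfg.nn.data.accelerator)  # "gpu", resolved through the interpolation
```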
4 changes: 2 additions & 2 deletions {{ cookiecutter.repository_name }}/conf/train/default.yaml
@@ -5,11 +5,11 @@ deterministic: False
# PyTorch Lightning Trainer https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
trainer:
fast_dev_run: False # Enable this for debug purposes
gpus: 1
accelerator: 'gpu'
devices: 1
precision: 32
max_epochs: 3
max_steps: 10000
accumulate_grad_batches: 1
num_sanity_val_steps: 2
gradient_clip_val: 10.0
val_check_interval: 1.0
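The old `gpus: 1` flag is split into `accelerator` and `devices`, matching the Lightning 2.0 Trainer API. A minimal sketch of how these keys map onto the Trainer, assuming they are passed through more or less directly (the template's actual wiring goes through Hydra):

```python
import lightning.pytorch as pl

# Values taken from the config defaults above; replaces the old `gpus: 1` flag.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=32,
    max_epochs=3,
    max_steps=10000,
    fast_dev_run=False,
)
```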
5 changes: 2 additions & 3 deletions {{ cookiecutter.repository_name }}/env.yaml
@@ -6,10 +6,9 @@ channels:

dependencies:
- python={{ cookiecutter.python_version }}
- pytorch==1.13.*
- pytorch-cuda=11.6
- pytorch=2.0.*
- torchvision
- torchaudio
- pytorch-cuda=11.8
- pip
- pip:
- -e .[dev]
8 changes: 4 additions & 4 deletions {{ cookiecutter.repository_name }}/setup.cfg
@@ -15,13 +15,13 @@ package_dir=
=src
packages=find:
install_requires =
nn-template-core==0.2.*
nn-template-core==0.3.*

# Add project specific dependencies
# Stuff easy to break with updates
pytorch-lightning==1.7.*
torchmetrics==0.10.*
hydra-core==1.2.*
lightning==2.0.*
torchmetrics==1.0.*
hydra-core==1.3.*
wandb
streamlit
# hydra-joblib-launcher
Additional changed file (path not shown): the Lightning DataModule
@@ -1,7 +1,7 @@
import logging
from functools import cached_property, partial
from pathlib import Path
from typing import List, Mapping, Optional, Sequence, Union
from typing import List, Mapping, Optional, Sequence

import hydra
import omegaconf
@@ -101,7 +101,7 @@ def __init__(
datasets: DictConfig,
num_workers: DictConfig,
batch_size: DictConfig,
gpus: Optional[Union[List[int], str, int]],
accelerator: str,
# example
val_percentage: float,
):
@@ -110,7 +110,7 @@ def __init__(
self.num_workers = num_workers
self.batch_size = batch_size
# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#gpus
self.pin_memory: bool = gpus is not None and str(gpus) != "0"
self.pin_memory: bool = accelerator is not None and str(accelerator) == "gpu"

self.train_dataset: Optional[Dataset] = None
self.val_datasets: Optional[Sequence[Dataset]] = None
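The datamodule now derives `pin_memory` from the accelerator string instead of the old `gpus` value. A stripped-down, hypothetical sketch of how that flag feeds a dataloader:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

accelerator = "gpu"  # would come from ${train.trainer.accelerator}
pin_memory = accelerator is not None and str(accelerator) == "gpu"

# Toy dataset standing in for the template's datasets.
dataset = TensorDataset(torch.randn(32, 3), torch.randint(0, 2, (32,)))
loader = DataLoader(dataset, batch_size=8, num_workers=0, pin_memory=pin_memory)
```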
Additional changed file (path not shown): the LightningModule and its Hydra entry point
@@ -32,7 +32,10 @@ def __init__(self, metadata: Optional[MetaData] = None, *args, **kwargs) -> None
self.metadata = metadata

# example
metric = torchmetrics.Accuracy()
metric = torchmetrics.Accuracy(
task="multiclass",
num_classes=len(metadata.class_vocab) if metadata is not None else None,
)
self.train_accuracy = metric.clone()
self.val_accuracy = metric.clone()
self.test_accuracy = metric.clone()
@@ -151,9 +154,11 @@ def main(cfg: omegaconf.DictConfig) -> None:
Args:
cfg: the hydra configuration
"""
module = cfg.nn.module
_: pl.LightningModule = hydra.utils.instantiate(
cfg.model,
optim=cfg.optim,
module,
optim=module.optimizer,
metadata=MetaData(class_vocab={str(i): i for i in range(10)}),
_recursive_=False,
)

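torchmetrics 1.x requires the classification task to be stated explicitly, which is why `task` and `num_classes` now appear above. A small self-contained sketch with made-up tensors and 10 classes, as in the `MetaData` example:

```python
import torch
import torchmetrics

# torchmetrics >= 1.0: task and num_classes are explicit for classification metrics.
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)

logits = torch.randn(8, 10)           # fake model outputs for a batch of 8
targets = torch.randint(0, 10, (8,))  # fake integer labels

accuracy.update(logits, targets)      # argmax over the class dimension happens internally
print(accuracy.compute())             # accuracy accumulated so far, as a scalar tensor
```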
Additional changed file (path not shown): the training run logic
@@ -54,7 +54,7 @@ def run(cfg: DictConfig) -> str:
if fast_dev_run:
pylogger.info(f"Debug mode <{cfg.train.trainer.fast_dev_run=}>. Forcing debugger friendly configuration!")
# Debuggers don't like GPUs nor multiprocessing
cfg.train.trainer.gpus = 0
cfg.train.trainer.accelerator = "cpu"
cfg.nn.data.num_workers.train = 0
cfg.nn.data.num_workers.val = 0
cfg.nn.data.num_workers.test = 0
2 changes: 1 addition & 1 deletion {{ cookiecutter.repository_name }}/tests/conftest.py
@@ -55,7 +55,7 @@ def cfg_simple_train(cfg: DictConfig) -> DictConfig:
cfg.core.tags = ["testing"]

# Disable gpus
cfg.train.trainer.gpus = 0
cfg.train.trainer.accelerator = "cpu"

# Disable logger
cfg.train.logging.logger.mode = "disabled"
Additional changed file (path not shown): the checkpoint loading test
@@ -25,7 +25,7 @@ def test_load_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig

checkpoint = NNCheckpointIO.load(path=checkpoint_path)

module = _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint["metadata"])
module = _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint["metadata"], strict=True)
assert module is not None
assert sum(p.numel() for p in module.parameters())

