-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* WIP: Add config/trainer.py with TrainerConfig * Rename common.device -> common.accelerator, return 'gpu' not 'cuda' if torch.cuda.is_available * Fix config section in doc/api/index.rst * Import trainer and TrainerConfig in src/vak/config/__init__.py, add to __all__ * Add pytorch-lightning to intersphinx in doc/conf.py * Fix cross-ref in docstring in src/vak/prep/frame_classification/make_splits.py: :constant: -> :const: * Make lightning a dependency, instead of pytorch_lightning; import lightning.pytorch everywhere instead of pytorch_lightning as lightning -- trying to make it so we can resolve API correctly in docstrings * Fix in doc/api/index.rst: common.device -> common.accelerator * Finish writing TrainerConfig class * Add tests for TrainerConfig class * Add trainer sub-table to all configs in tests/data_for_tests/configs * Add trainer sub-table to all configs in doc/toml * Add trainer sub-table in config/valid-version-1.0.toml, rename -> valid-version-1.1.toml * Remove device key from top-level tables in config/valid-version-1.1.toml * Remove device key from top-level tables in tests/data_for_tests/configs * Remove 'device' key from configs in doc/toml * Add 'trainer' attribute to EvalConfig, an instance of TrainerConfig; remove 'device' attribute * Add 'trainer' attribute to PredictConfig, an instance of TrainerConfig; remove 'device' attribute * Add 'trainer' attribute to TrainConfig, an instance of TrainerConfig; remove 'device' attribute * Fix typo in docstring in src/vak/config/train.py * Add 'trainer' attribute to LearncurveConfig, an instance of TrainerConfig; remove 'device' attribute. Also clean up docstring, removing attributes that no longer exist * Remove device attribute from TrainConfig docstring * Fix VALID_TOML_PATH in config/validators.py -> 'valid-version-1.1.toml' * Fix how we instantiate TrainerConfig classes in from_config_dict method of EvalConfig/LearncurveConfig/PredictConfig/TrainConfig * Fix typo in src/vak/config/valid-version-1.1.toml: predictor -> predict * Fix unit tests after adding trainer attribute that is instance of TrainerConfig * Change src/vak/train/frame_classification.py to take trainer_config argument * Change src/vak/train/parametric_umap.py to take trainer_config argument * Change src/vak/train/train_.py to take trainer_config argument * Fix src/vak/cli/train.py to pass trainer_config.asdict() into vak.train.train_.train * Replace 'device' with 'trainer_config' in vak/eval * Fix cli.eval to pass trainer_config into eval.eval_.eval * Replace 'device' with 'trainer_config' in vak/predict * Fix cli.predict to pass trainer_config into predict.predict_.predict * Replace 'device' with 'trainer_config' in vak/learncurve * Fix cli.learncurve to pass trainer_config into learncurve.learncurve.learning_curve * Rename/replace 'device' fixture with 'trainer' fixture in tests/ * Use config.table.trainer attribute throughout tests, remove config.table.device attribute that no longer exists * Fix value for devices in fixtures/trainer.py: when device is 'cpu' trainer must be > 0 * Fix default devices value for when accelerator is cpu in TrainerConfig * Fix unit tests for TrainerConfig after fixing default devices for accelerator=cpu * Fix default value for 'devices' set to -1 in some unit tests where we over-ride config in toml file * fixup use config.table.trainer attribute throughout tests -- missed one place in tests/test_eval/ * Add back 'device' fixture so we can use it to test Model class * Fix unit tests in test_models/test_base.by that literally used device to put tensors on device, not to change a config * Fix assertion in tests/test_models/test_tweetynet.py, from where we switched to using lightning as the dependency * Fix test for DiceLoss, change trainer_type fixture back to device fixture
- Loading branch information
1 parent
1e6abe6
commit e74b92b
Showing
71 changed files
with
772 additions
and
375 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import torch | ||
|
||
|
||
def get_default() -> str: | ||
"""Get default `accelerator` for :class:`lightning.pytorch.Trainer`. | ||
Returns | ||
------- | ||
accelerator : str | ||
Will be ``'gpu'`` if :func:`torch.cuda.is_available` | ||
is ``True``, and ``'cpu'`` if not. | ||
""" | ||
if torch.cuda.is_available(): | ||
return "gpu" | ||
else: | ||
return "cpu" |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.