Skip to content

Commit

Permalink
Rename model arg in load_pretrained_model to model_selection (#341)
Browse files Browse the repository at this point in the history
* rename model arg in load_pretrained_model to model_selection

* bumping scivision version to 0.3 due to slightly breaking change
  • Loading branch information
edwardchalstrey1 authored Oct 6, 2022
1 parent 1864523 commit 056b925
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions scivision/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def load_pretrained_model(
path: os.PathLike,
branch: str = "main",
allow_install: bool = False,
model: str = "default",
model_selection: str = "default",
load_multiple: bool = False,
*args,
**kwargs,
Expand All @@ -135,8 +135,8 @@ def load_pretrained_model(
Specify the name of a github branch if loading from github.
allow_install : bool, default = False
Allow installation of remote package via pip.
model : str, default = default
Specify the name of the model if there is > 1.
model_selection : str, default = default
Specify the name of the model if there is > 1 in the model repo package.
load_multiple : bool, default = False
Modifies the return to be a list of scivision.PretrainedModel's.
Expand All @@ -163,7 +163,7 @@ def load_pretrained_model(
with file as config_file:
stream = config_file.read()
config = yaml.safe_load(stream)
config_list = _get_model_configs(config, load_multiple, model)
config_list = _get_model_configs(config, load_multiple, model_selection)
loaded_models = []
for config in config_list:
# make sure a model at least has an input to the function
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name="scivision",
version="0.2.11",
version="0.3.0",
description="scivision",
author="The Alan Turing Institute",
author_email="[email protected]",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,20 @@ def test_load_pretrained_model_local():

def test_load_named_pretrained_model_local():
"""Test that scivision can load a specific model from the given scivision.yml."""
assert type(load_pretrained_model('tests/test_model_scivision.yml', allow_install=True, model='ImageNetModel')) == wrapper.PretrainedModel
assert type(load_pretrained_model('tests/test_multiple_models_scivision.yml', allow_install=True, model='ImageNetModel')) == wrapper.PretrainedModel
assert type(load_pretrained_model('tests/test_model_scivision.yml', allow_install=True, model_selection='ImageNetModel')) == wrapper.PretrainedModel
assert type(load_pretrained_model('tests/test_multiple_models_scivision.yml', allow_install=True, model_selection='ImageNetModel')) == wrapper.PretrainedModel


def test_load_wrong_model_name_raises_value_error():
"""Test that a value error is raised when a model name is specified that doesn't match the model in the config."""
with pytest.raises(ValueError):
load_pretrained_model('tests/test_model_scivision.yml', allow_install=True, model='FakeModel')
load_pretrained_model('tests/test_model_scivision.yml', allow_install=True, model_selection='FakeModel')


def test_load_wrong_model_name_raises_value_error_config_has_multiple_models():
"""Test that a value error is raised when a model name is specified that doesn't match one of the models in the config."""
with pytest.raises(ValueError):
load_pretrained_model('tests/test_multiple_models_scivision.yml', allow_install=True, model='FakeModel')
load_pretrained_model('tests/test_multiple_models_scivision.yml', allow_install=True, model_selection='FakeModel')


def test_load_multiple_models():
Expand Down

0 comments on commit 056b925

Please sign in to comment.