From 056b925708d51b26d438f18c4f4f146ec2bf5051 Mon Sep 17 00:00:00 2001 From: Ed Chalstrey Date: Thu, 6 Oct 2022 10:12:22 +0100 Subject: [PATCH] Rename model arg in load_pretrained_model to model_selection (#341) * rename model arg in load_pretrained_model to model_selection * bumping scivision version to 0.3 due to slightly breaking change --- scivision/io/reader.py | 8 ++++---- setup.py | 2 +- tests/test_reader.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/scivision/io/reader.py b/scivision/io/reader.py index b44c3c76..b5216ad9 100644 --- a/scivision/io/reader.py +++ b/scivision/io/reader.py @@ -120,7 +120,7 @@ def load_pretrained_model( path: os.PathLike, branch: str = "main", allow_install: bool = False, - model: str = "default", + model_selection: str = "default", load_multiple: bool = False, *args, **kwargs, @@ -135,8 +135,8 @@ def load_pretrained_model( Specify the name of a github branch if loading from github. allow_install : bool, default = False Allow installation of remote package via pip. - model : str, default = default - Specify the name of the model if there is > 1. + model_selection : str, default = default + Specify the name of the model if there is > 1 in the model repo package. load_multiple : bool, default = False Modifies the return to be a list of scivision.PretrainedModel's. @@ -163,7 +163,7 @@ def load_pretrained_model( with file as config_file: stream = config_file.read() config = yaml.safe_load(stream) - config_list = _get_model_configs(config, load_multiple, model) + config_list = _get_model_configs(config, load_multiple, model_selection) loaded_models = [] for config in config_list: # make sure a model at least has an input to the function diff --git a/setup.py b/setup.py index bfbdb773..a10647f0 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setup( name="scivision", - version="0.2.11", + version="0.3.0", description="scivision", author="The Alan Turing Institute", author_email="scivision@turing.ac.uk", diff --git a/tests/test_reader.py b/tests/test_reader.py index db769dfa..045e6f10 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -81,20 +81,20 @@ def test_load_pretrained_model_local(): def test_load_named_pretrained_model_local(): """Test that scivision can load a specific model from the given scivision.yml.""" - assert type(load_pretrained_model('tests/test_model_scivision.yml', allow_install=True, model='ImageNetModel')) == wrapper.PretrainedModel - assert type(load_pretrained_model('tests/test_multiple_models_scivision.yml', allow_install=True, model='ImageNetModel')) == wrapper.PretrainedModel + assert type(load_pretrained_model('tests/test_model_scivision.yml', allow_install=True, model_selection='ImageNetModel')) == wrapper.PretrainedModel + assert type(load_pretrained_model('tests/test_multiple_models_scivision.yml', allow_install=True, model_selection='ImageNetModel')) == wrapper.PretrainedModel def test_load_wrong_model_name_raises_value_error(): """Test that a value error is raised when a model name is specified that doesn't match the model in the config.""" with pytest.raises(ValueError): - load_pretrained_model('tests/test_model_scivision.yml', allow_install=True, model='FakeModel') + load_pretrained_model('tests/test_model_scivision.yml', allow_install=True, model_selection='FakeModel') def test_load_wrong_model_name_raises_value_error_config_has_multiple_models(): """Test that a value error is raised when a model name is specified that doesn't match one of the models in the config.""" with pytest.raises(ValueError): - load_pretrained_model('tests/test_multiple_models_scivision.yml', allow_install=True, model='FakeModel') + load_pretrained_model('tests/test_multiple_models_scivision.yml', allow_install=True, model_selection='FakeModel') def test_load_multiple_models():