diff --git a/deepmd/entrypoints/main.py b/deepmd/entrypoints/main.py index 35273bf4f7..5a674fe3d7 100644 --- a/deepmd/entrypoints/main.py +++ b/deepmd/entrypoints/main.py @@ -264,13 +264,21 @@ def main_parser() -> argparse.ArgumentParser: type=str, help="Frozen model file to import", ) - parser_tst.add_argument( + parser_tst_subgroup = parser_tst.add_mutually_exclusive_group() + parser_tst_subgroup.add_argument( "-s", "--system", default=".", type=str, help="The system dir. Recursively detect systems in this directory", ) + parser_tst_subgroup.add_argument( + "-f", + "--datafile", + default=None, + type=str, + help="The path to file of test list.", + ) parser_tst.add_argument( "-S", "--set-prefix", default="set", type=str, help="The set prefix" ) diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index 76057419f3..a4feaa88f6 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -23,6 +23,7 @@ def test( *, model: str, system: str, + datafile: str, set_prefix: str, numb_test: int, rand_seed: Optional[int], @@ -39,6 +40,8 @@ def test( path where model is stored system : str system directory + datafile : str + the path to the list of systems to test set_prefix : str string prefix of set numb_test : int @@ -57,7 +60,13 @@ def test( RuntimeError if no valid system was found """ - all_sys = expand_sys_str(system) + if datafile is not None: + datalist = open(datafile, 'r') + all_sys = datalist.read().splitlines() + datalist.close() + else: + all_sys = expand_sys_str(system) + if len(all_sys) == 0: raise RuntimeError("Did not find valid system") err_coll = []