Skip to content

Commit

Permalink
Support dp test with txt (#2235)
Browse files Browse the repository at this point in the history
Support using a text file(E.G. `list.txt`) to specify the list of test
sets

list.txt
```
data_dir/sys.000
data_dir/sys.001
data_dir/sys.002
...
```
`/data_dir/sys.000` should be a directory with "type.raw".
Then one can use `-l list.txt` to replace `-s data_dir`.
  • Loading branch information
HuangJiameng authored Jan 11, 2023
1 parent 6154494 commit 6322609
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
10 changes: 9 additions & 1 deletion deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,21 @@ def main_parser() -> argparse.ArgumentParser:
type=str,
help="Frozen model file to import",
)
parser_tst.add_argument(
parser_tst_subgroup = parser_tst.add_mutually_exclusive_group()
parser_tst_subgroup.add_argument(
"-s",
"--system",
default=".",
type=str,
help="The system dir. Recursively detect systems in this directory",
)
parser_tst_subgroup.add_argument(
"-f",
"--datafile",
default=None,
type=str,
help="The path to file of test list.",
)
parser_tst.add_argument(
"-S", "--set-prefix", default="set", type=str, help="The set prefix"
)
Expand Down
11 changes: 10 additions & 1 deletion deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test(
*,
model: str,
system: str,
datafile: str,
set_prefix: str,
numb_test: int,
rand_seed: Optional[int],
Expand All @@ -39,6 +40,8 @@ def test(
path where model is stored
system : str
system directory
datafile : str
the path to the list of systems to test
set_prefix : str
string prefix of set
numb_test : int
Expand All @@ -57,7 +60,13 @@ def test(
RuntimeError
if no valid system was found
"""
all_sys = expand_sys_str(system)
if datafile is not None:
datalist = open(datafile, 'r')
all_sys = datalist.read().splitlines()
datalist.close()
else:
all_sys = expand_sys_str(system)

if len(all_sys) == 0:
raise RuntimeError("Did not find valid system")
err_coll = []
Expand Down

0 comments on commit 6322609

Please sign in to comment.