Skip to content

Commit

Permalink
fix: incorporate feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
leahaeusel committed Dec 23, 2024
1 parent 7e563d5 commit 3a0a7a0
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 50 deletions.
8 changes: 6 additions & 2 deletions queens/drivers/jobscript_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,12 @@ def get_read_in_jobscript_template(jobscript_template):
if Path(jobscript_template).is_file():
jobscript_template = read_file(jobscript_template)
except OSError:
# We assume that the string already holds the jobscript template contents
pass
_logger.debug(
"The provided jobscript template string is not a regular file so we assume "
"that it holds the read-in jobscript template. The jobscript template reads:\n"
"%s",
{jobscript_template},
)

elif isinstance(jobscript_template, Path):
if jobscript_template.is_file():
Expand Down
105 changes: 57 additions & 48 deletions tests/unit_tests/drivers/test_jobscript_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Unit tests for the jobscript driver."""

import os
from contextlib import nullcontext as does_not_raise

import numpy as np
import pytest
Expand Down Expand Up @@ -192,65 +193,45 @@ def fixture_args_init(
These arguments are meant for initialization with the default
constructor.
"""
# pylint: disable=unused-argument
args_init = locals()
args_init["input_templates"] = args_init["input_template"]
args_init.pop("input_template")
args_init = {
"parameters": parameters,
"jobscript_template": jobscript_template,
"executable": executable,
"input_templates": input_template,
"files_to_copy": files_to_copy,
"data_processor": data_processor,
"gradient_data_processor": gradient_data_processor,
"jobscript_file_name": jobscript_file_name,
"extra_options": extra_options.copy(),
}
return args_init


def assert_jobscript_driver_attributes(jobscript_driver, args_init):
def assert_jobscript_driver_attributes(jobscript_driver, args_init, extra_options):
"""Assert that the jobscript driver attributes are set correctly."""
args_init["files_to_copy"].append(args_init["input_templates"])
extra_options.update({"executable": args_init["executable"]})

assert jobscript_driver.parameters == args_init["parameters"]
assert jobscript_driver.input_templates == {"input_file": args_init["input_templates"]}
assert jobscript_driver.jobscript_template == args_init["jobscript_template"]
assert jobscript_driver.jobscript_options["executable"] == args_init["executable"]
assert jobscript_driver.files_to_copy == args_init["files_to_copy"]
assert jobscript_driver.data_processor == args_init["data_processor"]
assert jobscript_driver.gradient_data_processor == args_init["gradient_data_processor"]
assert jobscript_driver.jobscript_file_name == args_init["jobscript_file_name"]
assert jobscript_driver.jobscript_options.items() >= args_init["extra_options"].items()
assert jobscript_driver.jobscript_options == extra_options


def assert_subprocess_error(
parameters, raise_error_on_jobscript_failure, jobscript_driver, job_options
):
"""Assert that jobscript driver run raises an error if so intended."""
sample_dict = parameters.sample_as_dict(np.array([1, 2]))
sample = np.array(list(sample_dict.values()))

if raise_error_on_jobscript_failure:
with pytest.raises(SubprocessError):
jobscript_driver.run(
sample=sample,
job_id=job_options.job_id,
num_procs=job_options.num_procs,
experiment_dir=job_options.experiment_dir,
experiment_name=job_options.experiment_name,
)
else:
jobscript_driver.run(
sample=sample,
job_id=job_options.job_id,
num_procs=job_options.num_procs,
experiment_dir=job_options.experiment_dir,
experiment_name=job_options.experiment_name,
)


def test_init_from_jobscript_template_str(args_init):
def test_init_from_jobscript_template_str(args_init, extra_options):
"""Test initialization of the JobscriptDriver.
For this initialization, the jobscript template is provided in the
form of a string describing the jobscript template contents.
"""
driver = JobscriptDriver(**args_init)
assert_jobscript_driver_attributes(driver, args_init)
assert_jobscript_driver_attributes(driver, args_init, extra_options)


def test_init_from_jobscript_template_path(args_init, jobscript_template_path):
def test_init_from_jobscript_template_path(args_init, jobscript_template_path, extra_options):
"""Test initialization of the JobscriptDriver.
For this initialization, the jobscript template is provided in the
Expand All @@ -259,7 +240,7 @@ def test_init_from_jobscript_template_path(args_init, jobscript_template_path):
args_init_from_jobscript_template_path = args_init.copy()
args_init_from_jobscript_template_path["jobscript_template"] = jobscript_template_path
driver = JobscriptDriver(**args_init_from_jobscript_template_path)
assert_jobscript_driver_attributes(driver, args_init)
assert_jobscript_driver_attributes(driver, args_init, extra_options)


def test_multiple_input_files(jobscript_driver, job_options, injected_input_files, parameters):
Expand All @@ -286,9 +267,15 @@ def test_multiple_input_files(jobscript_driver, job_options, injected_input_file
assert value == str(injectable_options[key])


@pytest.mark.parametrize("raise_error_on_jobscript_failure", [True, False])
@pytest.mark.parametrize(
"raise_error_on_jobscript_failure, expectation",
[
(False, does_not_raise()),
(True, pytest.raises(SubprocessError)),
],
)
def test_error_in_jobscript_template(
parameters, input_template, job_options, raise_error_on_jobscript_failure
parameters, input_template, job_options, raise_error_on_jobscript_failure, expectation
):
"""Test for an error when the jobscript template has an error."""
jobscript_driver = JobscriptDriver(
Expand All @@ -298,14 +285,28 @@ def test_error_in_jobscript_template(
executable="",
raise_error_on_jobscript_failure=raise_error_on_jobscript_failure,
)
assert_subprocess_error(
parameters, raise_error_on_jobscript_failure, jobscript_driver, job_options
)
sample_dict = parameters.sample_as_dict(np.array([1, 2]))
sample = np.array(list(sample_dict.values()))

with expectation:
jobscript_driver.run(
sample=sample,
job_id=job_options.job_id,
num_procs=job_options.num_procs,
experiment_dir=job_options.experiment_dir,
experiment_name=job_options.experiment_name,
)

@pytest.mark.parametrize("raise_error_on_jobscript_failure", [True, False])

@pytest.mark.parametrize(
"raise_error_on_jobscript_failure, expectation",
[
(False, does_not_raise()),
(True, pytest.raises(SubprocessError)),
],
)
def test_nonzero_exit_code(
parameters, input_template, job_options, raise_error_on_jobscript_failure
parameters, input_template, job_options, raise_error_on_jobscript_failure, expectation
):
"""Test for an error when the jobscript exits with a code other than 0."""
jobscript_driver = JobscriptDriver(
Expand All @@ -315,9 +316,17 @@ def test_nonzero_exit_code(
executable="",
raise_error_on_jobscript_failure=raise_error_on_jobscript_failure,
)
assert_subprocess_error(
parameters, raise_error_on_jobscript_failure, jobscript_driver, job_options
)
sample_dict = parameters.sample_as_dict(np.array([1, 2]))
sample = np.array(list(sample_dict.values()))

with expectation:
jobscript_driver.run(
sample=sample,
job_id=job_options.job_id,
num_procs=job_options.num_procs,
experiment_dir=job_options.experiment_dir,
experiment_name=job_options.experiment_name,
)


def test_long_jobscript_template_str(parameters, input_template):
Expand Down

0 comments on commit 3a0a7a0

Please sign in to comment.