Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to save Cadet sim as python file which can generate the Cadet sim again #10

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions cadet/cadet.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,32 @@ def load_json(self, filename, update=False):
else:
self.root = data

def save_as_python_script(self, filename: str, only_return_pythonic_representation=False):
if not filename.endswith(".py"):
raise Warning(f"The filename given to .save_as_python_script isn't a python file name.")

code_lines_list = [
"import numpy",
"from cadet import Cadet",
"",
"sim = Cadet()",
"root = sim.root",
]

code_lines_list = recursively_turn_dict_to_python_list(dictionary=self.root,
current_lines_list=code_lines_list,
prefix="root")

filename_for_reproduced_h5_file = filename.replace(".py", ".h5")
code_lines_list.append(f"sim.filename = '{filename_for_reproduced_h5_file}'")
code_lines_list.append("sim.save()")

if not only_return_pythonic_representation:
with open(filename, "w") as handle:
handle.writelines([line + "\n" for line in code_lines_list])
else:
return code_lines_list

def append(self, lock=False):
"This can only be used to write new keys to the system, this is faster than having to read the data before writing it"
if self.filename is not None:
Expand Down Expand Up @@ -347,3 +373,71 @@ def recursively_save(h5file, path, dic, func):
raise KeyError(f'Name conflict with upper and lower case entries for key "{path}{key}".')
else:
raise


def recursively_turn_dict_to_python_list(dictionary: dict, current_lines_list: list = None, prefix: str = None):
"""
Recursively turn a nested dictionary or addict.Dict into a list of Python code that
can generate the nested dictionary.

:param dictionary:
:param current_lines_list:
:param prefix_list:
:return: list of Python code lines
"""

def merge_to_absolute_key(prefix, key):
"""
Combine key and prefix to "prefix.key" except if there is no prefix, then return key
"""
if prefix is None:
return key
else:
return f"{prefix}.{key}"

def clean_up_key(absolute_key: str):
"""
Remove problematic phrases from key, such as blank "return"

:param absolute_key:
:return:
"""
absolute_key = absolute_key.replace(".return", "['return']")
return absolute_key

def get_pythonic_representation_of_value(value):
"""
Use repr() to get a pythonic representation of the value
and add "np." to "array" and "float64"

"""
value_representation = repr(value)
value_representation = value_representation.replace("array", "numpy.array")
value_representation = value_representation.replace("float64", "numpy.float64")
try:
eval(value_representation)
except NameError as e:
raise ValueError(
f"Encountered a value of '{value_representation}' that can't be directly reproduced in python.\n"
f"Please report this to the CADET-Python developers.") from e

return value_representation

if current_lines_list is None:
current_lines_list = []

for key in sorted(dictionary.keys()):
value = dictionary[key]

absolute_key = merge_to_absolute_key(prefix, key)

if type(value) in (dict, Dict):
current_lines_list = recursively_turn_dict_to_python_list(value, current_lines_list, prefix=absolute_key)
else:
value_representation = get_pythonic_representation_of_value(value)

absolute_key = clean_up_key(absolute_key)

current_lines_list.append(f"{absolute_key} = {value_representation}")

return current_lines_list
64 changes: 64 additions & 0 deletions tests/test_save_as_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import tempfile

import numpy as np
import pytest
from addict import Dict

from cadet import Cadet


@pytest.fixture
def temp_cadet_file():
"""
Create a new Cadet object for use in tests.
"""
model = Cadet()

with tempfile.NamedTemporaryFile() as temp:
model.filename = temp
yield model


def test_save_as_python(temp_cadet_file):
"""
Test that the Cadet class raises a KeyError exception when duplicate keys are set on it.
"""
# initialize "sim" variable to be overwritten by the exec lines later
sim = Cadet()

# Populate temp_cadet_file with all tricky cases currently known
temp_cadet_file.root.input.foo = 1
temp_cadet_file.root.input.bar.baryon = np.arange(10)
temp_cadet_file.root.input.bar.barometer = np.linspace(0, 10, 9)
temp_cadet_file.root.input.bar.init_q = np.array([], dtype=np.float64)
temp_cadet_file.root.input["return"].split_foobar = 1

code_lines = temp_cadet_file.save_as_python_script(filename="temp.py", only_return_pythonic_representation=True)

# remove code lines that save the file
code_lines = code_lines[:-2]

# populate "sim" variable using the generated code lines
for line in code_lines:
exec(line)

# test that "sim" is equal to "temp_cadet_file"
recursive_equality_check(sim.root, temp_cadet_file.root)


def recursive_equality_check(dict_a: dict, dict_b: dict):
assert dict_a.keys() == dict_b.keys()
for key in dict_a.keys():
value_a = dict_a[key]
value_b = dict_b[key]
if type(value_a) in (dict, Dict):
recursive_equality_check(value_a, value_b)
elif type(value_a) == np.ndarray:
np.testing.assert_array_equal(value_a, value_b)
else:
assert value_a == value_b
return True


if __name__ == "__main__":
pytest.main()
Loading