Skip to content

Commit

Permalink
get tests working again
Browse files Browse the repository at this point in the history
  • Loading branch information
sneakers-the-rat committed Oct 18, 2024
1 parent fcd8b65 commit 1701ef9
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 20 deletions.
4 changes: 2 additions & 2 deletions src/numpydantic/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
# pdb.set_trace()
interface_cls = Interface.match_output(value)
array = interface_cls.to_json(value, info)
array = postprocess_json(array, info)
array = postprocess_json(array, info, interface_cls)
return array


def postprocess_json(
array: Union[dict, list], info: SerializationInfo
array: Union[dict, list], info: SerializationInfo, interface_cls: type[Interface]
) -> Union[dict, list]:
"""
Modify json after dumping from an interface
Expand Down
8 changes: 4 additions & 4 deletions tests/test_interface/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,15 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):
for p in DTYPE_AND_INTERFACE_CASES_PASSING
)
)
def dtype_by_interface(request):
def all_passing_cases(request):
"""
Tests for all dtypes by all interfaces
"""
return request.param


@pytest.fixture()
def dtype_by_interface_instance(dtype_by_interface, tmp_output_dir_func):
array = dtype_by_interface.array(path=tmp_output_dir_func)
instance = dtype_by_interface.model(array=array)
def dtype_by_interface_instance(all_passing_cases, tmp_output_dir_func):
array = all_passing_cases.array(path=tmp_output_dir_func)
instance = all_passing_cases.model(array=array)
return instance
9 changes: 3 additions & 6 deletions tests/test_interface/test_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ class MyModel(BaseModel):
(H5Proxy(file="test_file.h5", path="/subpath", field="sup"), True),
(H5Proxy(file="test_file.h5", path="/subpath"), False),
(H5Proxy(file="different_file.h5", path="/subpath"), False),
(("different_file.h5", "/subpath", "sup"), ValueError),
("not even a proxy-like thing", ValueError),
(("different_file.h5", "/subpath", "sup"), False),
("not even a proxy-like thing", False),
],
)
def test_proxy_eq(comparison, valid):
Expand All @@ -232,8 +232,5 @@ def test_proxy_eq(comparison, valid):
proxy_a = H5Proxy(file="test_file.h5", path="/subpath", field="sup")
if valid is True:
assert proxy_a == comparison
elif valid is False:
assert proxy_a != comparison
else:
with pytest.raises(valid):
assert proxy_a == comparison
assert proxy_a != comparison
16 changes: 8 additions & 8 deletions tests/test_interface/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ def test_interface_dump_json(dtype_by_interface_instance):


@pytest.mark.serialization
def test_interface_roundtrip_json(dtype_by_interface, tmp_output_dir_func):
def test_interface_roundtrip_json(all_passing_cases, tmp_output_dir_func):
"""
All interfaces should be able to roundtrip to and from json
"""
if "subclass" in dtype_by_interface.id.lower():
if "subclass" in all_passing_cases.id.lower():
pytest.xfail()

array = dtype_by_interface.array(path=tmp_output_dir_func)
case = dtype_by_interface.model(array=array)
array = all_passing_cases.array(path=tmp_output_dir_func)
case = all_passing_cases.model(array=array)

dumped_json = case.model_dump_json(round_trip=True)
model = case.model_validate_json(dumped_json)
Expand All @@ -123,16 +123,16 @@ def test_interface_mark_interface(an_interface):
@pytest.mark.serialization
@pytest.mark.parametrize("valid", [True, False])
@pytest.mark.filterwarnings("ignore:Mismatch between serialized mark")
def test_interface_mark_roundtrip(dtype_by_interface, valid, tmp_output_dir_func):
def test_interface_mark_roundtrip(all_passing_cases, valid, tmp_output_dir_func):
"""
All interfaces should be able to roundtrip with the marked interface,
and a mismatch should raise a warning and attempt to proceed
"""
if "subclass" in dtype_by_interface.id.lower():
if "subclass" in all_passing_cases.id.lower():
pytest.xfail()

array = dtype_by_interface.array(path=tmp_output_dir_func)
case = dtype_by_interface.model(array=array)
array = all_passing_cases.array(path=tmp_output_dir_func)
case = all_passing_cases.model(array=array)

dumped_json = case.model_dump_json(
round_trip=True, context={"mark_interface": True}
Expand Down

0 comments on commit 1701ef9

Please sign in to comment.