diff --git a/tests/nested_dask/test_accessor.py b/tests/nested_dask/test_accessor.py index 8cc7603..2d89adf 100644 --- a/tests/nested_dask/test_accessor.py +++ b/tests/nested_dask/test_accessor.py @@ -1,3 +1,4 @@ +import nested_dask as nd import pandas as pd import pyarrow as pa import pytest @@ -19,75 +20,90 @@ def test_fields(test_dataset): assert test_dataset.nested.nest.fields == ["t", "flux", "band"] -def test_to_flat(test_dataset): +def test_to_flat(): """test the to_flat function""" - flat_ztf = test_dataset.nested.nest.to_flat() + nf = nd.datasets.generate_data(10, 100, npartitions=2, seed=1) + + flat_nf = nf.nested.nest.to_flat() # check dtypes - assert flat_ztf.dtypes["t"] == pd.ArrowDtype(pa.float64()) - assert flat_ztf.dtypes["flux"] == pd.ArrowDtype(pa.float64()) - assert flat_ztf.dtypes["band"] == pd.ArrowDtype(pa.large_string()) + assert flat_nf.dtypes["t"] == pd.ArrowDtype(pa.float64()) + assert flat_nf.dtypes["flux"] == pd.ArrowDtype(pa.float64()) + assert flat_nf.dtypes["band"] == pd.ArrowDtype(pa.string()) # Make sure we retain all rows - assert len(flat_ztf.loc[1]) == 500 + assert len(flat_nf.loc[1]) == 100 + + one_row = flat_nf.compute().iloc[0] - one_row = flat_ztf.loc[1].compute().iloc[1] - assert pytest.approx(one_row["t"], 0.01) == 5.4584 - assert pytest.approx(one_row["flux"], 0.01) == 84.1573 + assert pytest.approx(one_row["t"], 0.01) == 16.0149 + assert pytest.approx(one_row["flux"], 0.01) == 51.2061 assert one_row["band"] == "r" -def test_to_flat_with_fields(test_dataset): +def test_to_flat_with_fields(): """test the to_flat function""" - flat_ztf = test_dataset.nested.nest.to_flat(fields=["t", "flux"]) + nf = nd.datasets.generate_data(10, 100, npartitions=2, seed=1) + + flat_nf = nf.nested.nest.to_flat(fields=["t", "flux"]) + + assert "band" not in flat_nf.columns # check dtypes - assert flat_ztf.dtypes["t"] == pd.ArrowDtype(pa.float64()) - assert flat_ztf.dtypes["flux"] == pd.ArrowDtype(pa.float64()) + assert flat_nf.dtypes["t"] == pd.ArrowDtype(pa.float64()) + assert flat_nf.dtypes["flux"] == pd.ArrowDtype(pa.float64()) # Make sure we retain all rows - assert len(flat_ztf.loc[1]) == 500 + assert len(flat_nf.loc[1]) == 100 - one_row = flat_ztf.loc[1].compute().iloc[1] - assert pytest.approx(one_row["t"], 0.01) == 5.4584 - assert pytest.approx(one_row["flux"], 0.01) == 84.1573 + one_row = flat_nf.compute().iloc[0] + assert pytest.approx(one_row["t"], 0.01) == 16.0149 + assert pytest.approx(one_row["flux"], 0.01) == 51.2061 -def test_to_lists(test_dataset): + +def test_to_lists(): """test the to_lists function""" - list_ztf = test_dataset.nested.nest.to_lists() + + nf = nd.datasets.generate_data(10, 100, npartitions=2, seed=1) + list_nf = nf.nested.nest.to_lists() # check dtypes - assert list_ztf.dtypes["t"] == pd.ArrowDtype(pa.list_(pa.float64())) - assert list_ztf.dtypes["flux"] == pd.ArrowDtype(pa.list_(pa.float64())) - assert list_ztf.dtypes["band"] == pd.ArrowDtype(pa.list_(pa.large_string())) + assert list_nf.dtypes["t"] == pd.ArrowDtype(pa.list_(pa.float64())) + assert list_nf.dtypes["flux"] == pd.ArrowDtype(pa.list_(pa.float64())) + assert list_nf.dtypes["band"] == pd.ArrowDtype(pa.list_(pa.string())) # Make sure we have a single row for an id - assert len(list_ztf.loc[1]) == 1 + assert len(list_nf.loc[1]) == 1 # Make sure we retain all rows -- double loc for speed and pandas get_item - assert len(list_ztf.loc[1].compute().loc[1]["t"]) == 500 + assert len(list_nf.loc[1].compute().loc[1]["t"]) == 100 + one_row = list_nf.compute().iloc[1] # spot-check values - assert pytest.approx(list_ztf.loc[1].compute().loc[1]["t"][0], 0.01) == 7.5690279 - assert pytest.approx(list_ztf.loc[1].compute().loc[1]["flux"][0], 0.01) == 79.6886 - assert list_ztf.loc[1].compute().loc[1]["band"][0] == "g" + assert pytest.approx(one_row["t"][0], 0.01) == 19.3652 + assert pytest.approx(one_row["flux"][0], 0.01) == 61.7461 + assert one_row["band"][0] == "g" -def test_to_lists_with_fields(test_dataset): +def test_to_lists_with_fields(): """test the to_lists function""" - list_ztf = test_dataset.nested.nest.to_lists(fields=["t", "flux"]) + nf = nd.datasets.generate_data(10, 100, npartitions=2, seed=1) + list_nf = nf.nested.nest.to_lists(fields=["t", "flux"]) + + assert "band" not in list_nf.columns # check dtypes - assert list_ztf.dtypes["t"] == pd.ArrowDtype(pa.list_(pa.float64())) - assert list_ztf.dtypes["flux"] == pd.ArrowDtype(pa.list_(pa.float64())) + assert list_nf.dtypes["t"] == pd.ArrowDtype(pa.list_(pa.float64())) + assert list_nf.dtypes["flux"] == pd.ArrowDtype(pa.list_(pa.float64())) # Make sure we have a single row for an id - assert len(list_ztf.loc[1]) == 1 + assert len(list_nf.loc[1]) == 1 # Make sure we retain all rows -- double loc for speed and pandas get_item - assert len(list_ztf.loc[1].compute().loc[1]["t"]) == 500 + assert len(list_nf.loc[1].compute().loc[1]["t"]) == 100 + one_row = list_nf.compute().iloc[1] # spot-check values - assert pytest.approx(list_ztf.loc[1].compute().loc[1]["t"][0], 0.01) == 7.5690279 - assert pytest.approx(list_ztf.loc[1].compute().loc[1]["flux"][0], 0.01) == 79.6886 + assert pytest.approx(one_row["t"][0], 0.01) == 19.3652 + assert pytest.approx(one_row["flux"][0], 0.01) == 61.7461 diff --git a/tests/nested_dask/test_datasets.py b/tests/nested_dask/test_datasets.py index 26f9ad7..7a2e66a 100644 --- a/tests/nested_dask/test_datasets.py +++ b/tests/nested_dask/test_datasets.py @@ -1,4 +1,5 @@ import nested_dask as nd +import pytest def test_generate_data(): @@ -18,3 +19,8 @@ def test_generate_data(): # test the length assert len(generate_1) == 10 assert len(generate_1.nested.nest.to_flat()) == 1000 + + # test seed stability + assert pytest.approx(generate_1.compute().loc[0]["a"], 0.1) == 0.417 + assert pytest.approx(generate_1.compute().loc[0]["b"], 0.1) == 0.838 + assert pytest.approx(generate_1.nested.nest.to_flat().compute().iloc[0]["t"], 0.1) == 16.015