Support list of paths for parquet dataset. (#288)
delucchi-cmu authored Jun 10, 2024
1 parent b829115 commit a7338ca
Showing 2 changed files with 38 additions and 9 deletions.
25 changes: 16 additions & 9 deletions src/hipscat/io/file_io/file_io.py
@@ -198,34 +198,41 @@ def read_parquet_metadata(
 
 
 def read_parquet_dataset(
-    dir_pointer: FilePointer, storage_options: Union[Dict[Any, Any], None] = None, **kwargs
-) -> Tuple(FilePointer, Dataset):
-    """Read parquet dataset from directory pointer.
+    source: FilePointer, storage_options: Union[Dict[Any, Any], None] = None, **kwargs
+) -> Tuple[FilePointer, Dataset]:
+    """Read parquet dataset from directory pointer or list of files.
 
     Note that pyarrow.dataset reads require that directory pointers don't contain a
     leading slash, and the protocol prefix may additionally be removed. As such, we also return
     the directory path that is formatted for pyarrow ingestion for follow-up.
 
+    See more info on source specification and possible kwargs at
+    https://arrow.apache.org/docs/python/generated/pyarrow.dataset.dataset.html
+
     Args:
-        dir_pointer: location of file to read metadata from
+        source: directory, path, or list of paths to read data from
         storage_options: dictionary that contains abstract filesystem credentials
     Returns:
         Tuple containing a path to the dataset (that is formatted for pyarrow ingestion)
         and the dataset read from disk.
     """
-    file_system, dir_pointer = get_fs(file_pointer=dir_pointer, storage_options=storage_options)
-
-    # pyarrow.dataset requires the pointer not lead with a slash
-    dir_pointer = strip_leading_slash_for_pyarrow(dir_pointer, file_system.protocol)
+    if pd.api.types.is_list_like(source) and len(source) > 0:
+        sample_pointer = source[0]
+        file_system, sample_pointer = get_fs(file_pointer=sample_pointer, storage_options=storage_options)
+        source = [strip_leading_slash_for_pyarrow(path, file_system.protocol) for path in source]
+    else:
+        file_system, source = get_fs(file_pointer=source, storage_options=storage_options)
+        source = strip_leading_slash_for_pyarrow(source, file_system.protocol)
 
     dataset = pds.dataset(
-        dir_pointer,
+        source,
         filesystem=file_system,
         format="parquet",
         **kwargs,
     )
-    return (dir_pointer, dataset)
+    return (source, dataset)
 
 
 def read_parquet_file(file_pointer: FilePointer, storage_options: Union[Dict[Any, Any], None] = None):
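A minimal usage sketch of the updated function (not part of the commit: the import path is inferred from the test imports below, and the catalog paths are purely illustrative):

    # Sketch only: read_parquet_dataset now accepts a directory or a list of parquet paths.
    from hipscat.io.file_io import read_parquet_dataset

    # A single directory reads every parquet file beneath it.
    formatted_path, dataset = read_parquet_dataset("my_catalog/Norder=0")

    # A list of leaf files is read as one combined dataset.
    formatted_paths, dataset = read_parquet_dataset(
        [
            "my_catalog/Norder=1/Dir=0/Npix=44.parquet",
            "my_catalog/Norder=1/Dir=0/Npix=45.parquet",
        ]
    )
    print(dataset.count_rows())

In both cases the first element of the returned tuple is the pyarrow-formatted form of the input (leading slash stripped, protocol prefix possibly removed), ready to reuse for follow-up reads.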
22 changes: 22 additions & 0 deletions tests/hipscat/io/file_io/test_file_io.py
@@ -10,6 +10,7 @@
     load_json_file,
     load_parquet_to_pandas,
     make_directory,
+    read_parquet_dataset,
     read_parquet_file_to_pandas,
     remove_directory,
     write_dataframe_to_csv,
@@ -112,3 +113,24 @@ def test_read_parquet_data(tmp_path):
     file_pointer = get_file_pointer_from_path(test_file_path)
     dataframe = read_parquet_file_to_pandas(file_pointer)
     pd.testing.assert_frame_equal(dataframe, random_df)
+
+
+def test_read_parquet_dataset(small_sky_dir, small_sky_order1_dir):
+    (_, ds) = read_parquet_dataset(os.path.join(small_sky_dir, "Norder=0"))
+
+    assert ds.count_rows() == 131
+
+    (_, ds) = read_parquet_dataset([os.path.join(small_sky_dir, "Norder=0", "Dir=0", "Npix=11.parquet")])
+
+    assert ds.count_rows() == 131
+
+    (_, ds) = read_parquet_dataset(
+        [
+            os.path.join(small_sky_order1_dir, "Norder=1", "Dir=0", "Npix=44.parquet"),
+            os.path.join(small_sky_order1_dir, "Norder=1", "Dir=0", "Npix=45.parquet"),
+            os.path.join(small_sky_order1_dir, "Norder=1", "Dir=0", "Npix=46.parquet"),
+            os.path.join(small_sky_order1_dir, "Norder=1", "Dir=0", "Npix=47.parquet"),
+        ]
+    )
+
+    assert ds.count_rows() == 131

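The three assertions above push the same 131 rows through a directory read, a single-file list, and a multi-file list. For reference, a standalone sketch of the pyarrow behavior the new code path relies on (file names here are hypothetical, not from the commit):

    # pyarrow.dataset.dataset() accepts a list of parquet paths and exposes
    # them as one logical dataset; this is what read_parquet_dataset forwards to.
    import pyarrow as pa
    import pyarrow.dataset as pds
    import pyarrow.parquet as pq

    # Write two small parquet files.
    pq.write_table(pa.table({"id": [1, 2]}), "part0.parquet")
    pq.write_table(pa.table({"id": [3, 4]}), "part1.parquet")

    # Read both back as a single dataset and count rows across both files.
    dataset = pds.dataset(["part0.parquet", "part1.parquet"], format="parquet")
    assert dataset.count_rows() == 4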