diff --git a/src/hipscat/io/file_io/file_io.py b/src/hipscat/io/file_io/file_io.py
index d465fc7b..d3357dfd 100644
--- a/src/hipscat/io/file_io/file_io.py
+++ b/src/hipscat/io/file_io/file_io.py
@@ -198,34 +198,41 @@ def read_parquet_metadata(


 def read_parquet_dataset(
-    dir_pointer: FilePointer, storage_options: Union[Dict[Any, Any], None] = None, **kwargs
-) -> Tuple(FilePointer, Dataset):
-    """Read parquet dataset from directory pointer.
+    source: FilePointer, storage_options: Union[Dict[Any, Any], None] = None, **kwargs
+) -> Tuple[FilePointer, Dataset]:
+    """Read parquet dataset from directory pointer or list of files.

     Note that pyarrow.dataset reads require that directory pointers don't contain a
     leading slash, and the protocol prefix may additionally be removed. As such, we also return
     the directory path that is formatted for pyarrow ingestion for follow-up.

+    See more info on source specification and possible kwargs at
+    https://arrow.apache.org/docs/python/generated/pyarrow.dataset.dataset.html
+
     Args:
-        dir_pointer: location of file to read metadata from
+        source: directory, path, or list of paths to read data from
         storage_options: dictionary that contains abstract filesystem credentials

     Returns:
         Tuple containing a path to the dataset (that is formatted for pyarrow ingestion)
         and the dataset read from disk.
     """
-    file_system, dir_pointer = get_fs(file_pointer=dir_pointer, storage_options=storage_options)
-    # pyarrow.dataset requires the pointer not lead with a slash
-    dir_pointer = strip_leading_slash_for_pyarrow(dir_pointer, file_system.protocol)
+    if pd.api.types.is_list_like(source) and len(source) > 0:
+        sample_pointer = source[0]
+        file_system, sample_pointer = get_fs(file_pointer=sample_pointer, storage_options=storage_options)
+        source = [strip_leading_slash_for_pyarrow(path, file_system.protocol) for path in source]
+    else:
+        file_system, source = get_fs(file_pointer=source, storage_options=storage_options)
+        source = strip_leading_slash_for_pyarrow(source, file_system.protocol)

     dataset = pds.dataset(
-        dir_pointer,
+        source,
         filesystem=file_system,
         format="parquet",
         **kwargs,
     )
-    return (dir_pointer, dataset)
+    return (source, dataset)


 def read_parquet_file(file_pointer: FilePointer, storage_options: Union[Dict[Any, Any], None] = None):
diff --git a/tests/hipscat/io/file_io/test_file_io.py b/tests/hipscat/io/file_io/test_file_io.py
index ab33beae..8c69439b 100644
--- a/tests/hipscat/io/file_io/test_file_io.py
+++ b/tests/hipscat/io/file_io/test_file_io.py
@@ -10,6 +10,7 @@
     load_json_file,
     load_parquet_to_pandas,
     make_directory,
+    read_parquet_dataset,
     read_parquet_file_to_pandas,
     remove_directory,
     write_dataframe_to_csv,
@@ -112,3 +113,24 @@ def test_read_parquet_data(tmp_path):
     file_pointer = get_file_pointer_from_path(test_file_path)
     dataframe = read_parquet_file_to_pandas(file_pointer)
     pd.testing.assert_frame_equal(dataframe, random_df)
+
+
+def test_read_parquet_dataset(small_sky_dir, small_sky_order1_dir):
+    (_, ds) = read_parquet_dataset(os.path.join(small_sky_dir, "Norder=0"))
+
+    assert ds.count_rows() == 131
+
+    (_, ds) = read_parquet_dataset([os.path.join(small_sky_dir, "Norder=0", "Dir=0", "Npix=11.parquet")])
+
+    assert ds.count_rows() == 131
+
+    (_, ds) = read_parquet_dataset(
+        [
+            os.path.join(small_sky_order1_dir, "Norder=1", "Dir=0", "Npix=44.parquet"),
+            os.path.join(small_sky_order1_dir, "Norder=1", "Dir=0", "Npix=45.parquet"),
+            os.path.join(small_sky_order1_dir, "Norder=1", "Dir=0", "Npix=46.parquet"),
+            os.path.join(small_sky_order1_dir, "Norder=1", "Dir=0", "Npix=47.parquet"),
+        ]
+    )
+
+    assert ds.count_rows() == 131