diff --git a/src/hats_import/catalog/run_import.py b/src/hats_import/catalog/run_import.py index 6db80cd5..e3f8b3a7 100644 --- a/src/hats_import/catalog/run_import.py +++ b/src/hats_import/catalog/run_import.py @@ -6,7 +6,6 @@ import os import pickle -from pathlib import Path import hats.io.file_io as io from hats.catalog import PartitionInfo @@ -19,27 +18,13 @@ from hats_import.catalog.resume_plan import ResumePlan -def _validate_arguments(args): - """ - Verify that the args for run are valid: they exist, are of the appropriate type, - and do not specify an output which is a valid catalog. - - Raises ValueError if they are invalid. - """ +def run(args, client): + """Run catalog creation pipeline.""" if not args: raise ValueError("args is required and should be type ImportArguments") if not isinstance(args, ImportArguments): raise ValueError("args must be type ImportArguments") - potential_path = Path(args.output_path) / args.output_artifact_name - if is_valid_catalog(potential_path): - raise ValueError(f"Output path {potential_path} already contains a valid catalog") - - -def run(args, client): - """Run catalog creation pipeline.""" - _validate_arguments(args) - resume_plan = ResumePlan(import_args=args) pickled_reader_file = os.path.join(resume_plan.tmp_path, "reader.pickle") @@ -137,7 +122,7 @@ def run(args, client): # All done - write out the metadata if resume_plan.should_run_finishing: - with resume_plan.print_progress(total=4, stage_name="Finishing") as step_progress: + with resume_plan.print_progress(total=5, stage_name="Finishing") as step_progress: partition_info = PartitionInfo.from_healpix(resume_plan.get_destination_pixels()) partition_info_file = paths.get_partition_info_pointer(args.catalog_path) partition_info.write_to_file(partition_info_file) @@ -151,12 +136,14 @@ def run(args, client): else: partition_info.write_to_metadata_files(args.catalog_path) step_progress.update(1) + io.write_fits_image(raw_histogram, paths.get_point_map_file_pointer(args.catalog_path)) + step_progress.update(1) catalog_info = args.to_table_properties( total_rows, partition_info.get_highest_order(), partition_info.calculate_fractional_coverage() ) catalog_info.to_properties_file(args.catalog_path) step_progress.update(1) - io.write_fits_image(raw_histogram, paths.get_point_map_file_pointer(args.catalog_path)) - step_progress.update(1) resume_plan.clean_resume_files() step_progress.update(1) + assert is_valid_catalog(args.catalog_path) + step_progress.update(1) diff --git a/src/hats_import/margin_cache/margin_cache.py b/src/hats_import/margin_cache/margin_cache.py index 4beddfbc..ace4af4f 100644 --- a/src/hats_import/margin_cache/margin_cache.py +++ b/src/hats_import/margin_cache/margin_cache.py @@ -1,8 +1,5 @@ -from pathlib import Path - from hats.catalog import PartitionInfo from hats.io import file_io, parquet_metadata, paths -from hats.io.validation import is_valid_catalog import hats_import.margin_cache.margin_cache_map_reduce as mcmr from hats_import.margin_cache.margin_cache_resume_plan import MarginCachePlan @@ -18,11 +15,6 @@ def generate_margin_cache(args, client): args (MarginCacheArguments): A valid `MarginCacheArguments` object. client (dask.distributed.Client): A dask distributed client object. """ - potential_path = Path(args.output_path) / args.output_artifact_name - # Verify that the planned output path is not occupied by a valid catalog - if is_valid_catalog(potential_path): - raise ValueError(f"Output path {potential_path} already contains a valid catalog") - resume_plan = MarginCachePlan(args) original_catalog_metadata = paths.get_common_metadata_pointer(args.input_catalog_path) diff --git a/src/hats_import/runtime_arguments.py b/src/hats_import/runtime_arguments.py index cae7322e..b349a39c 100644 --- a/src/hats_import/runtime_arguments.py +++ b/src/hats_import/runtime_arguments.py @@ -9,6 +9,7 @@ from pathlib import Path from hats.io import file_io +from hats.io.validation import is_valid_catalog from upath import UPath # pylint: disable=too-many-instance-attributes @@ -89,6 +90,8 @@ def _check_arguments(self): raise ValueError("dask_threads_per_worker should be greater than 0") self.catalog_path = file_io.get_upath(self.output_path) / self.output_artifact_name + if is_valid_catalog(self.catalog_path): + raise ValueError(f"Output path {self.catalog_path} already contains a valid catalog") if not self.resume: file_io.remove_directory(self.catalog_path, ignore_errors=True) file_io.make_directory(self.catalog_path, exist_ok=True) diff --git a/tests/hats_import/catalog/test_argument_validation.py b/tests/hats_import/catalog/test_argument_validation.py index b6c827f9..632f545c 100644 --- a/tests/hats_import/catalog/test_argument_validation.py +++ b/tests/hats_import/catalog/test_argument_validation.py @@ -241,3 +241,16 @@ def test_check_healpix_order_range(): check_healpix_order_range("two", "order_field") with pytest.raises(TypeError, match="not supported"): check_healpix_order_range(5, "order_field", upper_bound="ten") + + +def test_no_import_overwrite(small_sky_object_catalog, parquet_shards_dir): + """Runner should refuse to overwrite a valid catalog""" + catalog_dir = small_sky_object_catalog.parent + catalog_name = small_sky_object_catalog.name + with pytest.raises(ValueError, match="already contains a valid catalog"): + ImportArguments( + input_path=parquet_shards_dir, + output_path=catalog_dir, + output_artifact_name=catalog_name, + file_reader="parquet", + ) diff --git a/tests/hats_import/catalog/test_run_import.py b/tests/hats_import/catalog/test_run_import.py index 1e895f6d..3070de72 100644 --- a/tests/hats_import/catalog/test_run_import.py +++ b/tests/hats_import/catalog/test_run_import.py @@ -13,11 +13,9 @@ from hats.pixel_math.sparse_histogram import SparseHistogram import hats_import.catalog.run_import as runner -import hats_import.margin_cache.margin_cache as margin_runner from hats_import.catalog.arguments import ImportArguments from hats_import.catalog.file_readers import CsvReader from hats_import.catalog.resume_plan import ResumePlan -from hats_import.margin_cache.margin_cache_arguments import MarginCacheArguments def test_empty_args(): @@ -33,34 +31,6 @@ def test_bad_args(): runner.run(args, None) -def test_no_import_overwrite(small_sky_object_catalog, parquet_shards_dir): - """Runner should refuse to overwrite a valid catalog""" - catalog_dir = small_sky_object_catalog.parent - catalog_name = small_sky_object_catalog.name - args = ImportArguments( - input_path=parquet_shards_dir, - output_path=catalog_dir, - output_artifact_name=catalog_name, - file_reader="parquet", - ) - with pytest.raises(ValueError, match="already contains a valid catalog"): - runner.run(args, None) - - -def test_no_margin_cache_overwrite(small_sky_object_catalog): - """Runner should refuse to generate margin cache which overwrites valid catalog""" - catalog_dir = small_sky_object_catalog.parent - catalog_name = small_sky_object_catalog.name - args = MarginCacheArguments( - input_catalog_path=small_sky_object_catalog, - output_path=catalog_dir, - margin_threshold=10.0, - output_artifact_name=catalog_name, - ) - with pytest.raises(ValueError, match="already contains a valid catalog"): - margin_runner.generate_margin_cache(args, None) - - @pytest.mark.dask def test_resume_dask_runner( dask_client, diff --git a/tests/hats_import/catalog/test_run_round_trip.py b/tests/hats_import/catalog/test_run_round_trip.py index ea3bb8f4..d86b4156 100644 --- a/tests/hats_import/catalog/test_run_round_trip.py +++ b/tests/hats_import/catalog/test_run_round_trip.py @@ -287,15 +287,14 @@ def test_import_delete_provided_temp_directory( """Test that ALL intermediate files (and temporary base directory) are deleted after successful import, when both `delete_intermediate_parquet_files` and `delete_resume_log_files` are set to True.""" - output_dir = tmp_path_factory.mktemp("small_sky_object_catalog") + output_dir = tmp_path_factory.mktemp("catalogs") # Provided temporary directory, outside `output_dir` temp = tmp_path_factory.mktemp("intermediate_files") - base_intermediate_dir = temp / "small_sky_object_catalog" / "intermediate" # When at least one of the delete flags is set to False we do # not delete the provided temporary base directory. args = ImportArguments( - output_artifact_name="small_sky_object_catalog", + output_artifact_name="keep_log_files", input_path=small_sky_parts_dir, file_reader="csv", output_path=output_dir, @@ -307,10 +306,10 @@ def test_import_delete_provided_temp_directory( delete_resume_log_files=False, ) runner.run(args, dask_client) - assert_stage_level_files_exist(base_intermediate_dir) + assert_stage_level_files_exist(temp / "keep_log_files" / "intermediate") args = ImportArguments( - output_artifact_name="small_sky_object_catalog", + output_artifact_name="keep_parquet_intermediate", input_path=small_sky_parts_dir, file_reader="csv", output_path=output_dir, @@ -323,11 +322,11 @@ def test_import_delete_provided_temp_directory( resume=False, ) runner.run(args, dask_client) - assert_intermediate_parquet_files_exist(base_intermediate_dir) + assert_intermediate_parquet_files_exist(temp / "keep_parquet_intermediate" / "intermediate") # The temporary directory is deleted. args = ImportArguments( - output_artifact_name="small_sky_object_catalog", + output_artifact_name="remove_all_intermediate", input_path=small_sky_parts_dir, file_reader="csv", output_path=output_dir, @@ -340,7 +339,7 @@ def test_import_delete_provided_temp_directory( resume=False, ) runner.run(args, dask_client) - assert not os.path.exists(temp) + assert not os.path.exists(temp / "remove_all_intermediate") def assert_stage_level_files_exist(base_intermediate_dir): diff --git a/tests/hats_import/margin_cache/test_arguments_margin_cache.py b/tests/hats_import/margin_cache/test_arguments_margin_cache.py index 75d20589..cb92d263 100644 --- a/tests/hats_import/margin_cache/test_arguments_margin_cache.py +++ b/tests/hats_import/margin_cache/test_arguments_margin_cache.py @@ -126,3 +126,16 @@ def test_to_table_properties(small_sky_source_catalog, tmp_path): assert catalog_info.total_rows == 10 assert catalog_info.ra_column == "source_ra" assert catalog_info.dec_column == "source_dec" + + +def test_no_margin_cache_overwrite(small_sky_object_catalog): + """Runner should refuse to generate margin cache which overwrites valid catalog""" + catalog_dir = small_sky_object_catalog.parent + catalog_name = small_sky_object_catalog.name + with pytest.raises(ValueError, match="already contains a valid catalog"): + MarginCacheArguments( + input_catalog_path=small_sky_object_catalog, + output_path=catalog_dir, + margin_threshold=10.0, + output_artifact_name=catalog_name, + ) diff --git a/tests/hats_import/test_runtime_arguments.py b/tests/hats_import/test_runtime_arguments.py index 2bb16874..d677737d 100644 --- a/tests/hats_import/test_runtime_arguments.py +++ b/tests/hats_import/test_runtime_arguments.py @@ -123,10 +123,10 @@ def test_dask_args(tmp_path): ) -def test_extra_property_dict(test_data_dir): +def test_extra_property_dict(tmp_path): args = RuntimeArguments( output_artifact_name="small_sky_source_catalog", - output_path=test_data_dir, + output_path=tmp_path, ) properties = args.extra_property_dict() @@ -141,13 +141,13 @@ def test_extra_property_dict(test_data_dir): # Most values are dynamic, but these are some safe assumptions. assert properties["hats_builder"].startswith("hats") assert properties["hats_creation_date"].startswith("20") - assert properties["hats_estsize"] > 1_000 + assert properties["hats_estsize"] >= 0 assert properties["hats_release_date"].startswith("20") assert properties["hats_version"].startswith("v") args = RuntimeArguments( output_artifact_name="small_sky_source_catalog", - output_path=test_data_dir, + output_path=tmp_path, addl_hats_properties={"foo": "bar"}, ) @@ -164,7 +164,7 @@ def test_extra_property_dict(test_data_dir): # Most values are dynamic, but these are some safe assumptions. assert properties["hats_builder"].startswith("hats") assert properties["hats_creation_date"].startswith("20") - assert properties["hats_estsize"] > 1_000 + assert properties["hats_estsize"] >= 0 assert properties["hats_release_date"].startswith("20") assert properties["hats_version"].startswith("v") assert properties["foo"] == "bar"