Skip to content

Commit

Permalink
Remove typing imports for List, Tuple, Union (#441)
Browse files Browse the repository at this point in the history
* Remove List, Tuple, Union

* oops
  • Loading branch information
delucchi-cmu authored Dec 9, 2024
1 parent e3d899a commit f21c1fb
Show file tree
Hide file tree
Showing 27 changed files with 84 additions and 114 deletions.
15 changes: 7 additions & 8 deletions src/hats/catalog/association_catalog/association_catalog.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

from typing import Union

import pandas as pd
import pyarrow as pa
from mocpy import MOC

from hats.catalog.association_catalog.partition_join_info import PartitionJoinInfo
from hats.catalog.dataset.table_properties import TableProperties
from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset, PixelInputTypes
from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset
from hats.catalog.partition_info import PartitionInfo
from hats.pixel_math import HealpixPixel
from hats.pixel_tree.pixel_tree import PixelTree


class AssociationCatalog(HealpixDataset):
Expand All @@ -19,13 +20,11 @@ class AssociationCatalog(HealpixDataset):
Catalog, corresponding to each pair of partitions in each catalog that contain rows to join.
"""

JoinPixelInputTypes = Union[list, pd.DataFrame, PartitionJoinInfo]

def __init__(
self,
catalog_info: TableProperties,
pixels: PixelInputTypes,
join_pixels: JoinPixelInputTypes,
pixels: PartitionInfo | PixelTree | list[HealpixPixel],
join_pixels: list | pd.DataFrame | PartitionJoinInfo,
catalog_path=None,
moc: MOC | None = None,
schema: pa.Schema | None = None,
Expand All @@ -44,7 +43,7 @@ def get_join_pixels(self) -> pd.DataFrame:

@staticmethod
def _get_partition_join_info_from_pixels(
join_pixels: JoinPixelInputTypes,
join_pixels: list | pd.DataFrame | PartitionJoinInfo,
) -> PartitionJoinInfo:
if isinstance(join_pixels, PartitionJoinInfo):
return join_pixels
Expand Down
3 changes: 1 addition & 2 deletions src/hats/catalog/association_catalog/partition_join_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import warnings
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -46,7 +45,7 @@ def _check_column_names(self):
if column not in self.data_frame.columns:
raise ValueError(f"join_info_df does not contain column {column}")

def primary_to_join_map(self) -> Dict[HealpixPixel, List[HealpixPixel]]:
def primary_to_join_map(self) -> dict[HealpixPixel, list[HealpixPixel]]:
"""Generate a map from a single primary pixel to one or more pixels in the join catalog.
Lots of cute comprehension is happening here, so watch out!
Expand Down
4 changes: 1 addition & 3 deletions src/hats/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

from typing import List

from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset
from hats.pixel_math import HealpixPixel
from hats.pixel_tree.negative_tree import compute_negative_tree_pixels
Expand All @@ -17,7 +15,7 @@ class Catalog(HealpixDataset):
`Norder=/Dir=/Npix=.parquet`
"""

def generate_negative_tree_pixels(self) -> List[HealpixPixel]:
def generate_negative_tree_pixels(self) -> list[HealpixPixel]:
"""Get the leaf nodes at each healpix order that have zero catalog data.
For example, if an example catalog only had data points in pixel 0 at
Expand Down
5 changes: 2 additions & 3 deletions src/hats/catalog/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from pathlib import Path
from typing import List

import pyarrow as pa
from upath import UPath
Expand Down Expand Up @@ -41,8 +40,8 @@ def __init__(
def aggregate_column_statistics(
self,
exclude_hats_columns: bool = True,
exclude_columns: List[str] = None,
include_columns: List[str] = None,
exclude_columns: list[str] = None,
include_columns: list[str] = None,
):
"""Read footer statistics in parquet metadata, and report on global min/max values.
Expand Down
12 changes: 6 additions & 6 deletions src/hats/catalog/dataset/table_properties.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from pathlib import Path
from typing import Iterable, List, Optional, Union
from typing import Iterable, Optional

from jproperties import Properties
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_validator
Expand Down Expand Up @@ -90,7 +90,7 @@ class TableProperties(BaseModel):

ra_column: Optional[str] = Field(default=None, alias="hats_col_ra")
dec_column: Optional[str] = Field(default=None, alias="hats_col_dec")
default_columns: Optional[List[str]] = Field(default=None, alias="hats_cols_default")
default_columns: Optional[list[str]] = Field(default=None, alias="hats_cols_default")
"""Which columns should be read from parquet files, when user doesn't otherwise specify."""

primary_catalog: Optional[str] = Field(default=None, alias="hats_primary_table_url")
Expand Down Expand Up @@ -120,15 +120,15 @@ class TableProperties(BaseModel):
indexing_column: Optional[str] = Field(default=None, alias="hats_index_column")
"""Column that we provide an index over."""

extra_columns: Optional[List[str]] = Field(default=None, alias="hats_index_extra_column")
extra_columns: Optional[list[str]] = Field(default=None, alias="hats_index_extra_column")
"""Any additional payload columns included in index."""

## Allow any extra keyword args to be stored on the properties object.
model_config = ConfigDict(extra="allow", populate_by_name=True, use_enum_values=True)

@field_validator("default_columns", "extra_columns", mode="before")
@classmethod
def space_delimited_list(cls, str_value: str) -> List[str]:
def space_delimited_list(cls, str_value: str) -> list[str]:
"""Convert a space-delimited list string into a python list of strings."""
if isinstance(str_value, str):
# Split on a few kinds of delimiters (just to be safe), and remove duplicates
Expand Down Expand Up @@ -193,7 +193,7 @@ def __str__(self):
return formatted_string

@classmethod
def read_from_dir(cls, catalog_dir: Union[str, Path, UPath]) -> Self:
def read_from_dir(cls, catalog_dir: str | Path | UPath) -> Self:
"""Read field values from a java-style properties file."""
file_path = file_io.get_upath(catalog_dir) / "properties"
if not file_io.does_file_or_directory_exist(file_path):
Expand All @@ -203,7 +203,7 @@ def read_from_dir(cls, catalog_dir: Union[str, Path, UPath]) -> Self:
p.load(f, "utf-8")
return cls(**p.properties)

def to_properties_file(self, catalog_dir: Union[str, Path, UPath]) -> Self:
def to_properties_file(self, catalog_dir: str | Path | UPath) -> Self:
"""Write fields to a java-style properties file."""
# pylint: disable=protected-access
parameters = self.model_dump(by_alias=True, exclude_none=True)
Expand Down
17 changes: 8 additions & 9 deletions src/hats/catalog/healpix_dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from pathlib import Path
from typing import List, Tuple, Union

import astropy.units as u
import numpy as np
Expand Down Expand Up @@ -30,8 +29,6 @@
from hats.pixel_tree.pixel_alignment import align_with_mocs
from hats.pixel_tree.pixel_tree import PixelTree

PixelInputTypes = Union[PartitionInfo, PixelTree, List[HealpixPixel]]


class HealpixDataset(Dataset):
"""A HATS dataset partitioned with a HEALPix partitioning structure.
Expand All @@ -45,7 +42,7 @@ class HealpixDataset(Dataset):
def __init__(
self,
catalog_info: TableProperties,
pixels: PixelInputTypes,
pixels: PartitionInfo | PixelTree | list[HealpixPixel],
catalog_path: str | Path | UPath | None = None,
moc: MOC | None = None,
schema: pa.Schema | None = None,
Expand All @@ -66,7 +63,7 @@ def __init__(
self.pixel_tree = self._get_pixel_tree_from_pixels(pixels)
self.moc = moc

def get_healpix_pixels(self) -> List[HealpixPixel]:
def get_healpix_pixels(self) -> list[HealpixPixel]:
"""Get healpix pixel objects for all pixels contained in the catalog.
Returns:
Expand All @@ -75,7 +72,9 @@ def get_healpix_pixels(self) -> List[HealpixPixel]:
return self.partition_info.get_healpix_pixels()

@staticmethod
def _get_partition_info_from_pixels(pixels: PixelInputTypes) -> PartitionInfo:
def _get_partition_info_from_pixels(
pixels: PartitionInfo | PixelTree | list[HealpixPixel],
) -> PartitionInfo:
if isinstance(pixels, PartitionInfo):
return pixels
if isinstance(pixels, PixelTree):
Expand All @@ -85,7 +84,7 @@ def _get_partition_info_from_pixels(pixels: PixelInputTypes) -> PartitionInfo:
raise TypeError("Pixels must be of type PartitionInfo, PixelTree, or List[HealpixPixel]")

@staticmethod
def _get_pixel_tree_from_pixels(pixels: PixelInputTypes) -> PixelTree:
def _get_pixel_tree_from_pixels(pixels: PartitionInfo | PixelTree | list[HealpixPixel]) -> PixelTree:
if isinstance(pixels, PartitionInfo):
return PixelTree.from_healpix(pixels.get_healpix_pixels())
if isinstance(pixels, PixelTree):
Expand Down Expand Up @@ -118,7 +117,7 @@ def get_max_coverage_order(self) -> int:
)
return max_order

def filter_from_pixel_list(self, pixels: List[HealpixPixel]) -> Self:
def filter_from_pixel_list(self, pixels: list[HealpixPixel]) -> Self:
"""Filter the pixels in the catalog to only include any that overlap with the requested pixels.
Args:
Expand Down Expand Up @@ -155,7 +154,7 @@ def filter_by_cone(self, ra: float, dec: float, radius_arcsec: float) -> Self:
)
return self.filter_by_moc(cone_moc)

def filter_by_box(self, ra: Tuple[float, float], dec: Tuple[float, float]) -> Self:
def filter_by_box(self, ra: tuple[float, float], dec: tuple[float, float]) -> Self:
"""Filter the pixels in the catalog to only include the pixels that overlap with a
zone, defined by right ascension and declination ranges. The right ascension edges follow
great arc circles and the declination edges follow small arc circles.
Expand Down
4 changes: 1 addition & 3 deletions src/hats/catalog/index/index_catalog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

import numpy as np
import pyarrow.compute as pc
import pyarrow.dataset as pds
Expand All @@ -16,7 +14,7 @@ class IndexCatalog(Dataset):
Note that this is not a true "HATS Catalog", as it is not partitioned spatially.
"""

def loc_partitions(self, ids) -> List[HealpixPixel]:
def loc_partitions(self, ids) -> list[HealpixPixel]:
"""Find the set of partitions in the primary catalog for the ids provided.
Args:
Expand Down
9 changes: 4 additions & 5 deletions src/hats/catalog/partition_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import warnings
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
Expand All @@ -27,11 +26,11 @@ class PartitionInfo:
METADATA_ORDER_COLUMN_NAME = "Norder"
METADATA_PIXEL_COLUMN_NAME = "Npix"

def __init__(self, pixel_list: List[HealpixPixel], catalog_base_dir: str = None) -> None:
def __init__(self, pixel_list: list[HealpixPixel], catalog_base_dir: str = None) -> None:
self.pixel_list = pixel_list
self.catalog_base_dir = catalog_base_dir

def get_healpix_pixels(self) -> List[HealpixPixel]:
def get_healpix_pixels(self) -> list[HealpixPixel]:
"""Get healpix pixel objects for all pixels represented as partitions.
Returns:
Expand Down Expand Up @@ -158,7 +157,7 @@ def read_from_file(cls, metadata_file: str | Path | UPath, strict: bool = False)
@classmethod
def _read_from_metadata_file(
cls, metadata_file: str | Path | UPath, strict: bool = False
) -> List[HealpixPixel]:
) -> list[HealpixPixel]:
"""Read partition info list from a `_metadata` file.
Args:
Expand Down Expand Up @@ -260,7 +259,7 @@ def as_dataframe(self):
return pd.DataFrame.from_dict(partition_info_dict)

@classmethod
def from_healpix(cls, healpix_pixels: List[HealpixPixel]) -> PartitionInfo:
def from_healpix(cls, healpix_pixels: list[HealpixPixel]) -> PartitionInfo:
"""Create a partition info object from a list of constituent healpix pixels.
Args:
Expand Down
3 changes: 1 addition & 2 deletions src/hats/inspection/almanac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import os
import warnings
from typing import List

import pandas as pd

Expand Down Expand Up @@ -218,7 +217,7 @@ def _get_linked_catalog(self, linked_text, namespace) -> AlmanacInfo | None:
return None
return self.entries[resolved_name]

def catalogs(self, include_deprecated=False, types: List[str] | None = None):
def catalogs(self, include_deprecated=False, types: list[str] | None = None):
"""Get names of catalogs in the almanac, matching the provided conditions.
Catalogs must meet all criteria provided in order to be returned (e.g.
Expand Down
17 changes: 8 additions & 9 deletions src/hats/inspection/almanac_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import os
from dataclasses import dataclass, field
from typing import List

import yaml
from typing_extensions import Self
Expand All @@ -24,14 +23,14 @@ class AlmanacInfo:
join: str | None = None
primary_link: Self | None = None
join_link: Self | None = None
sources: List[Self] = field(default_factory=list)
objects: List[Self] = field(default_factory=list)
margins: List[Self] = field(default_factory=list)
associations: List[Self] = field(default_factory=list)
associations_right: List[Self] = field(default_factory=list)
indexes: List[Self] = field(default_factory=list)

creators: List[str] = field(default_factory=list)
sources: list[Self] = field(default_factory=list)
objects: list[Self] = field(default_factory=list)
margins: list[Self] = field(default_factory=list)
associations: list[Self] = field(default_factory=list)
associations_right: list[Self] = field(default_factory=list)
indexes: list[Self] = field(default_factory=list)

creators: list[str] = field(default_factory=list)
description: str = ""
version: str = ""
deprecated: str = ""
Expand Down
Loading

0 comments on commit f21c1fb

Please sign in to comment.