Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
VeckoTheGecko committed Dec 5, 2024
1 parent 132183e commit d8713c8
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions parcels/field.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import math
import warnings
from collections.abc import Iterable
from ctypes import POINTER, Structure, c_float, c_int, pointer
from glob import glob
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Literal

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -50,6 +52,8 @@

from parcels.fieldset import FieldSet

T_SanitizedFilenames = list[str] | dict[str, list[str]]

__all__ = ["Field", "NestedField", "VectorField"]


Expand Down Expand Up @@ -484,7 +488,7 @@ def from_netcdf(
time_periodic: TimePeriodic = False,
deferred_load: bool = True,
**kwargs,
) -> "Field":
) -> Field:
"""Create field from netCDF file.
Parameters
Expand Down Expand Up @@ -587,14 +591,14 @@ def from_netcdf(
), "The variable tuple must have length 2. Use FieldSet.from_netcdf() for multiple variables"

data_filenames = _get_dim_filenames(filenames, "data")
lonlat_filename = _get_dim_filenames(filenames, "lon")
lonlat_filename_lst = _get_dim_filenames(filenames, "lon")
if isinstance(filenames, dict):
assert len(lonlat_filename) == 1
if lonlat_filename != _get_dim_filenames(filenames, "lat"):
assert len(lonlat_filename_lst) == 1
if lonlat_filename_lst != _get_dim_filenames(filenames, "lat"):
raise NotImplementedError(
"longitude and latitude dimensions are currently processed together from one single file"
)
lonlat_filename = lonlat_filename[0]
lonlat_filename = lonlat_filename_lst[0]
if "depth" in dimensions:
depth_filename = _get_dim_filenames(filenames, "depth")
if isinstance(filenames, dict) and len(depth_filename) != 1:
Expand Down Expand Up @@ -2574,7 +2578,7 @@ def __getitem__(self, key):
return val


def _get_dim_filenames(filenames: str | Path | Any | dict[str, str | Any], dim: str) -> Any:
def _get_dim_filenames(filenames: T_SanitizedFilenames, dim: str) -> list[str]:
"""Get's the relevant filenames for a given dimension."""
if isinstance(filenames, list):
return filenames
Expand All @@ -2585,7 +2589,7 @@ def _get_dim_filenames(filenames: str | Path | Any | dict[str, str | Any], dim:
raise ValueError("Filenames must be a string, pathlib.Path, or a dictionary")


def _sanitize_field_filenames(filenames, *, recursed=False):
def _sanitize_field_filenames(filenames, *, recursed=False) -> T_SanitizedFilenames:
"""The Field initializer can take `filenames` to be of various formats including:
1. a string or Path object. String can be a glob expression.
Expand Down

0 comments on commit d8713c8

Please sign in to comment.