Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Set posfilter files to no longer load when --from-cache set #302

Merged
merged 4 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/gnatss/ops/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,37 @@ def get_data_inputs(all_observations: pd.DataFrame) -> NumbaList:
return data_inputs


def prefilter_replies(
all_observations: pd.DataFrame,
num_transponders: int,
) -> pd.DataFrame:
"""
Remove pings that do receive replies from each
transponder in the array.

Parameters
----------
all_observations : pd.DataFrame
The original observations that include every ping and reply
num_transponders : int
The number of transponders in the array

Returns
-------
pd.DataFrame
The observations where the number of replies equal the
number of transponders
"""
# Get value counts for transmit times
time_counts = all_observations[constants.DATA_SPEC.tx_time].value_counts()

return all_observations[
all_observations[constants.DATA_SPEC.tx_time].isin(
time_counts[time_counts == num_transponders].index
)
]


def clean_tt(
travel_times: pd.DataFrame,
transponder_ids: list[str],
Expand Down
4 changes: 2 additions & 2 deletions src/gnatss/ops/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ def load_datasets(
# Gather main
all_files_dict.update(gather_files(config, proc="main", mode=mode))

# Gather posfilter
if not skip_posfilter:
# Gather posfilter (Skip if from_cache set)
if not skip_posfilter and not from_cache:
all_files_dict.update(gather_files(config, proc="posfilter", mode=mode))

# Gather solver
Expand Down
12 changes: 8 additions & 4 deletions src/gnatss/solver/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .. import constants
from ..configs.main import Configuration
from ..configs.solver import SolverTransponder
from ..ops.data import filter_tt, get_data_inputs
from ..ops.data import filter_tt, get_data_inputs, prefilter_replies
from ..ops.validate import check_solutions
from ..utilities.geo import _get_rotation_matrix
from ..utilities.time import AstroTime
Expand Down Expand Up @@ -387,15 +387,19 @@ def prepare_and_solve(
# Store original xyz
original_positions = transponders_xyz.copy()

# Store number of transponders
num_transponders = len(transponders)

typer.echo("Preparing data inputs...")
data_inputs = get_data_inputs(all_observations)
typer.echo(f"Pre-filtering data with fewer than {num_transponders} replies...")
reduced_observations = prefilter_replies(all_observations, num_transponders)
data_inputs = get_data_inputs(reduced_observations)

typer.echo("Perform solve...")
is_converged = False
n_iter = 0
num_transponders = len(transponders)
process_dict = {}
num_data = len(all_observations)
num_data = len(reduced_observations)
typer.echo(f"--- {len(data_inputs)} epochs, {num_data} measurements ---")
while not is_converged:
# Max converge attempt failure
Expand Down