Skip to content

Commit

Permalink
Add normalisation post constraint function in BernoulliParticleUpdater
Browse files Browse the repository at this point in the history
  • Loading branch information
sdhiscocks committed Apr 9, 2024
1 parent 06b82cd commit 068adf6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
3 changes: 3 additions & 0 deletions stonesoup/updater/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,9 @@ def update(self, hypotheses, **kwargs):
if self.constraint_func is not None:
part_indx = self.constraint_func(updated_state)
updated_state.log_weight[part_indx] = -1*np.inf
if not any(hypotheses):
updated_state.log_weight = copy.copy(updated_state.log_weight)

Check warning on line 533 in stonesoup/updater/particle.py

View check run for this annotation

Codecov / codecov/patch

stonesoup/updater/particle.py#L533

Added line #L533 was not covered by tests
updated_state.log_weight -= logsumexp(updated_state.log_weight)

# Resampling
if self.resampler is not None:
Expand Down
11 changes: 9 additions & 2 deletions stonesoup/updater/tests/test_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def dummy_constraint_function(particles):
return part_indx


@pytest.fixture(params=[True, False])
def constraint_func(request):
if request.param:
return dummy_constraint_function


@pytest.fixture(params=(
ParticleUpdater,
partial(ParticleUpdater, resampler=SystematicResampler()),
Expand Down Expand Up @@ -102,7 +108,7 @@ def test_particle(updater):
assert np.allclose(updated_state.mean, StateVectors([[15.0], [20.0]]), rtol=5e-2)


def test_bernoulli_particle():
def test_bernoulli_particle(constraint_func):
timestamp = datetime.datetime.now()
timediff = 2
new_timestamp = timestamp + datetime.timedelta(seconds=timediff)
Expand Down Expand Up @@ -173,7 +179,8 @@ def test_bernoulli_particle():
clutter_rate=2,
clutter_distribution=1/10,
nsurv_particles=9,
detection_probability=detection_probability)
detection_probability=detection_probability,
constraint_func=constraint_func)

hypotheses = MultipleHypothesis(
[SingleHypothesis(prediction, detection) for detection in detections])
Expand Down

0 comments on commit 068adf6

Please sign in to comment.