From 068adf662b537c9cd5fac0304eae2292f6ffd797 Mon Sep 17 00:00:00 2001 From: Steven Hiscocks Date: Tue, 9 Apr 2024 12:37:32 +0100 Subject: [PATCH] Add normalisation post constraint function in BernoulliParticleUpdater --- stonesoup/updater/particle.py | 3 +++ stonesoup/updater/tests/test_particle.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/stonesoup/updater/particle.py b/stonesoup/updater/particle.py index 2b9708175..8f8aeb665 100644 --- a/stonesoup/updater/particle.py +++ b/stonesoup/updater/particle.py @@ -529,6 +529,9 @@ def update(self, hypotheses, **kwargs): if self.constraint_func is not None: part_indx = self.constraint_func(updated_state) updated_state.log_weight[part_indx] = -1*np.inf + if not any(hypotheses): + updated_state.log_weight = copy.copy(updated_state.log_weight) + updated_state.log_weight -= logsumexp(updated_state.log_weight) # Resampling if self.resampler is not None: diff --git a/stonesoup/updater/tests/test_particle.py b/stonesoup/updater/tests/test_particle.py index d22d1d7aa..d88d42448 100644 --- a/stonesoup/updater/tests/test_particle.py +++ b/stonesoup/updater/tests/test_particle.py @@ -35,6 +35,12 @@ def dummy_constraint_function(particles): return part_indx +@pytest.fixture(params=[True, False]) +def constraint_func(request): + if request.param: + return dummy_constraint_function + + @pytest.fixture(params=( ParticleUpdater, partial(ParticleUpdater, resampler=SystematicResampler()), @@ -102,7 +108,7 @@ def test_particle(updater): assert np.allclose(updated_state.mean, StateVectors([[15.0], [20.0]]), rtol=5e-2) -def test_bernoulli_particle(): +def test_bernoulli_particle(constraint_func): timestamp = datetime.datetime.now() timediff = 2 new_timestamp = timestamp + datetime.timedelta(seconds=timediff) @@ -173,7 +179,8 @@ def test_bernoulli_particle(): clutter_rate=2, clutter_distribution=1/10, nsurv_particles=9, - detection_probability=detection_probability) + detection_probability=detection_probability, + constraint_func=constraint_func) hypotheses = MultipleHypothesis( [SingleHypothesis(prediction, detection) for detection in detections])