Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Optuna Sensor Manager #1106

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
python -m venv venv
. venv/bin/activate
pip install --upgrade pip
pip install -e .[dev,orbital] opencv-python-headless pyehm
pip install -e .[dev,ehm,optuna,orbital] opencv-python-headless
- save_cache:
paths:
- ./venv
Expand Down Expand Up @@ -75,7 +75,7 @@ jobs:
python -m venv venv
. venv/bin/activate
pip install --upgrade pip
pip install -e .[orbital] opencv-python-headless plotly pytest-cov pytest-remotedata pytest-skip-slow pyehm confluent-kafka h5py pandas
pip install -e .[ehm,optuna,orbital] opencv-python-headless plotly pytest-cov pytest-remotedata pytest-skip-slow pyehm confluent-kafka h5py pandas
- save_cache:
paths:
- ./venv
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ mfa = [
ehm = [
"pyehm",
]
optuna = [
"optuna",
]

[tool.setuptools]
include-package-data = false
Expand Down
4 changes: 4 additions & 0 deletions stonesoup/sensormanager/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def min(self):
def max(self):
raise NotImplementedError

@abstractmethod
def action_from_value(self):
raise NotImplementedError


class ActionableProperty(Property):
"""Property that is modified via an :class:`~.Action` with defined, non-equal start and end
Expand Down
93 changes: 93 additions & 0 deletions stonesoup/sensormanager/optuna.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Iterable
from collections import defaultdict
import warnings

try:
import optuna
except ImportError as error:
raise ImportError("Usage of Optuna Sensor Manager requires that the optional package "
"`optuna`is installed") from error

from ..base import Property
from ..sensor.sensor import Sensor
from .action import RealNumberActionGenerator, Action
from . import SensorManager


class OptunaSensorManager(SensorManager):
"""Sensor Manager that uses the optuna package to determine the best actions available within
a time frame specified by :attr:`timeout`."""
timeout: float = Property(
doc="Number of seconds that the sensor manager should optimise for each time-step",
default=10.)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
optuna.logging.set_verbosity(optuna.logging.CRITICAL)

def choose_actions(self, tracks, timestamp, nchoose=1, **kwargs) -> Iterable[tuple[Sensor,
Action]]:
"""Method to find the best actions for the given :attr:`sensors` to according to the
:attr:`reward_function`.

Parameters
----------
tracks_list : List[Track]
List of Tracks for the sensor manager to observe.
timestamp: datetime.datetime
The time for the actions to be produced for.

Returns
-------
Iterable[Tuple[Sensor, Action]]
The actions and associated sensors produced by the sensor manager."""
all_action_generators = dict()

for sensor in self.sensors:
action_generators = sensor.actions(timestamp)
all_action_generators[sensor] = action_generators # set of generators

def config_from_trial(trial):
config = defaultdict(list)
for i, (sensor, generators) in enumerate(all_action_generators.items()):

for j, generator in enumerate(generators):
if isinstance(generator, RealNumberActionGenerator):
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
value = trial.suggest_float(
f'{i}{j}', generator.min, generator.max + generator.epsilon,
step=getattr(generator, 'resolution', None))
else:
raise TypeError(f"type {type(generator)} not handled yet")
action = generator.action_from_value(value)
if action is not None:
config[sensor].append(action)
else:
config[sensor].append(generator.default_action)
return config

def optimise_func(trial):
config = config_from_trial(trial)

return -self.reward_function(config, tracks, timestamp)

study = optuna.create_study()
# will finish study after `timeout` seconds has elapsed.
study.optimize(optimise_func, n_trials=None, timeout=self.timeout)

best_params = study.best_params
config = defaultdict(list)
for i, (sensor, generators) in enumerate(all_action_generators.items()):
for j, generator in enumerate(generators):
if isinstance(generator, RealNumberActionGenerator):
action = generator.action_from_value(best_params[f'{i}{j}'])
else:
raise TypeError(f"generator type {type(generator)} not supported")
if action is not None:
config[sensor].append(action)
else:
config[sensor].append(generator.default_action)

# Return mapping of sensors and chosen actions for sensors
return [config]
79 changes: 79 additions & 0 deletions stonesoup/sensormanager/tests/test_optuna.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import copy
from collections import defaultdict
import pytest
from ordered_set import OrderedSet
import numpy as np

try:
from ..optuna import OptunaSensorManager
except ImportError:
# Catch optional dependencies import error
pytest.skip(
"Skipping due to missing optional dependencies. Usage of Optuna Sensor Manager requires "
"that the optional package `optuna`is installed.",
allow_module_level=True
)

from ..reward import UncertaintyRewardFunction
from ...hypothesiser.distance import DistanceHypothesiser
from ...measures import Mahalanobis
from ...dataassociator.neighbour import GNNWith2DAssignment
from ...sensor.radar.radar import RadarRotatingBearingRange
from ...sensor.action.dwell_action import ChangeDwellAction


def test_optuna_manager(params):
predictor = params['predictor']
updater = params['updater']
sensor_set = params['sensor_set']
timesteps = params['timesteps']
tracks = params['tracks']
truths = params['truths']

reward_function = UncertaintyRewardFunction(predictor, updater)
optunasensormanager = OptunaSensorManager(sensor_set, reward_function=reward_function,
timeout=0.1)

hypothesiser = DistanceHypothesiser(predictor, updater, measure=Mahalanobis(),
missed_distance=5)
data_associator = GNNWith2DAssignment(hypothesiser)

sensor_history = defaultdict(dict)
dwell_centres = dict()

for timestep in timesteps[1:]:
chosen_actions = optunasensormanager.choose_actions(tracks, timestep)
measurements = set()
for chosen_action in chosen_actions:
for sensor, actions in chosen_action.items():
sensor.add_actions(actions)
for sensor in sensor_set:
sensor.act(timestep)
sensor_history[timestep][sensor] = copy.copy(sensor)
dwell_centres[timestep] = sensor.dwell_centre[0][0]
measurements |= sensor.measure(OrderedSet(truth[timestep] for truth in truths),
noise=False)
hypotheses = data_associator.associate(tracks,
measurements,
timestep)
for track in tracks:
hypothesis = hypotheses[track]
if hypothesis.measurement:
post = updater.update(hypothesis)
track.append(post)
else:
track.append(hypothesis.prediction)

# Double check choose_actions method types are as expected
assert isinstance(chosen_actions, list)

for chosen_actions in chosen_actions:
for sensor, actions in chosen_action.items():
assert isinstance(sensor, RadarRotatingBearingRange)
assert isinstance(actions[0], ChangeDwellAction)

# Check sensor following track as expected
assert dwell_centres[timesteps[5]] - np.radians(135) < 1e-3
assert dwell_centres[timesteps[15]] - np.radians(45) < 1e-3
assert dwell_centres[timesteps[25]] - np.radians(-45) < 1e-3
assert dwell_centres[timesteps[35]] - np.radians(-135) < 1e-3