diff --git a/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai b/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai
index 67cad75f..3cad8001 100644
--- a/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai
+++ b/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai
@@ -52,7 +52,7 @@
- hmm.most_likely_states([59.7382107943162,3.99285183724331])
+ hmm.infer_state([59.7382107943162,3.99285183724331])
diff --git a/src/Bonsai.ML.HiddenMarkovModels/main.py b/src/Bonsai.ML.HiddenMarkovModels/main.py
index 3f83b932..f9d7d6b8 100644
--- a/src/Bonsai.ML.HiddenMarkovModels/main.py
+++ b/src/Bonsai.ML.HiddenMarkovModels/main.py
@@ -122,15 +122,15 @@ def update_params(self, initial_state_distribution, transitions_params, observat
else:
self.observations_params = (hmm_params[2],)
- def get_predicted_states(self):
- self.predicted_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int)
-
def infer_state(self, observation: list[float]):
- self.log_alpha = self.compute_log_alpha(
- np.expand_dims(np.array(observation), 0), self.log_alpha)
+ observation = np.expand_dims(np.array(observation), 0)
+ self.log_alpha = self.compute_log_alpha(observation, self.log_alpha)
self.state_probabilities = np.exp(self.log_alpha).astype(np.double)
- return self.state_probabilities.argmax()
+ prediction = self.state_probabilities.argmax()
+ self.predicted_states = np.append(self.predicted_states, prediction)
+ self.batch_observations = np.vstack([self.batch_observations, observation])
+ return prediction
def compute_log_alpha(self, obs, log_alpha=None):
@@ -174,8 +174,6 @@ def fit_async(self,
self.batch = np.vstack(
[self.batch[1:], np.expand_dims(np.array(observation), 0)])
- self.batch_observations = self.batch
-
if not self.is_running and self.loop is None and self.thread is None:
if self.curr_batch_size >= batch_size:
@@ -224,8 +222,6 @@ def on_completion(future):
if self.flush_data_between_batches:
self.batch = None
- self.get_predicted_states()
-
self.is_running = True
if self.loop is None or self.loop.is_closed():