diff --git a/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai b/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai index 67cad75f..3cad8001 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai +++ b/src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai @@ -52,7 +52,7 @@ - hmm.most_likely_states([59.7382107943162,3.99285183724331]) + hmm.infer_state([59.7382107943162,3.99285183724331]) diff --git a/src/Bonsai.ML.HiddenMarkovModels/main.py b/src/Bonsai.ML.HiddenMarkovModels/main.py index 3f83b932..f9d7d6b8 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/main.py +++ b/src/Bonsai.ML.HiddenMarkovModels/main.py @@ -122,15 +122,15 @@ def update_params(self, initial_state_distribution, transitions_params, observat else: self.observations_params = (hmm_params[2],) - def get_predicted_states(self): - self.predicted_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int) - def infer_state(self, observation: list[float]): - self.log_alpha = self.compute_log_alpha( - np.expand_dims(np.array(observation), 0), self.log_alpha) + observation = np.expand_dims(np.array(observation), 0) + self.log_alpha = self.compute_log_alpha(observation, self.log_alpha) self.state_probabilities = np.exp(self.log_alpha).astype(np.double) - return self.state_probabilities.argmax() + prediction = self.state_probabilities.argmax() + self.predicted_states = np.append(self.predicted_states, prediction) + self.batch_observations = np.vstack([self.batch_observations, observation]) + return prediction def compute_log_alpha(self, obs, log_alpha=None): @@ -174,8 +174,6 @@ def fit_async(self, self.batch = np.vstack( [self.batch[1:], np.expand_dims(np.array(observation), 0)]) - self.batch_observations = self.batch - if not self.is_running and self.loop is None and self.thread is None: if self.curr_batch_size >= batch_size: @@ -224,8 +222,6 @@ def on_completion(future): if self.flush_data_between_batches: self.batch = None - self.get_predicted_states() - self.is_running = True if self.loop is None or self.loop.is_closed():