diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index fd4216d..54f7840 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -107,10 +107,8 @@ def init_d(self, rating: Tensor) -> Tensor: def next_d(self, state: Tensor, rating: Tensor) -> Tensor: delta_d = - self.w[6] * (rating - 3) - if delta_d > 0: - new_d = state[:, 1] + delta_d * (1 - (state[:, 1]/10)) - if delta_d < 0: - new_d = state[:, 1] + delta_d * (1 - ((11 - state[:, 1])/10)) + test = state[:, 1] > 5.5 + new_d = torch.where(test, state[:, 1] + delta_d * (1 - (state[:, 1]/10)), state[:, 1] + delta_d * (1 - ((11 - state[:, 1])/10))) new_d = self.mean_reversion(self.init_d(4), new_d) return new_d