From b30ba6b22fbe6ebe3f3ff5c18eeffe97f79d7600 Mon Sep 17 00:00:00 2001 From: user1823 <92206575+user1823@users.noreply.github.com> Date: Thu, 17 Oct 2024 17:11:27 +0530 Subject: [PATCH] Update fsrs_optimizer.py --- src/fsrs_optimizer/fsrs_optimizer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index fd4216d..54f7840 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -107,10 +107,8 @@ def init_d(self, rating: Tensor) -> Tensor: def next_d(self, state: Tensor, rating: Tensor) -> Tensor: delta_d = - self.w[6] * (rating - 3) - if delta_d > 0: - new_d = state[:, 1] + delta_d * (1 - (state[:, 1]/10)) - if delta_d < 0: - new_d = state[:, 1] + delta_d * (1 - ((11 - state[:, 1])/10)) + test = state[:, 1] > 5.5 + new_d = torch.where(test, state[:, 1] + delta_d * (1 - (state[:, 1]/10)), state[:, 1] + delta_d * (1 - ((11 - state[:, 1])/10))) new_d = self.mean_reversion(self.init_d(4), new_d) return new_d