Skip to content

Commit

Permalink
broaden seasonality metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
xjing76 committed Jun 29, 2020
1 parent b8c47d9 commit decf50c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
15 changes: 10 additions & 5 deletions tests/test_sampling_seasonality.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
import numpy as np
from datetime import datetime

# %%
import random


def test_seasonality_sampling(N: int = 200, off_param=1):
random.seed(123)
kwargs = gen_defualt_params_seaonality(N)
simulation, _ = simulate_poiszero_hmm(**kwargs)

Expand Down Expand Up @@ -49,7 +53,7 @@ def test_seasonality_sampling(N: int = 200, off_param=1):
[
pm.Constant.dist(0),
pm.Poisson.dist(E_1_mu * seasonal),
pm.Poisson.dist((E_1_mu + E_2_mu) * seasonal),
pm.Poisson.dist((E_2_mu) * seasonal),
],
S_rv,
observed=y_test,
Expand All @@ -66,13 +70,14 @@ def test_seasonality_sampling(N: int = 200, off_param=1):

st_trace = trace_.posterior["S_t"].mean(axis=0).mean(axis=0)
mean_error_rate = (
1 - np.sum(np.equal(st_trace, simulation["S_t"]) * 1) / len(simulation["S_t"])
1
- np.sum(np.equal(st_trace == 0, simulation["S_t"] == 0) * 1)
/ len(simulation["S_t"])
).values.tolist()

positive_index = simulation["Y_t"] > 0
positive_sim = simulation["Y_t"][positive_index]
MAPE = np.nanmean(abs(y_trace[positive_index] - positive_sim) / positive_sim)

assert mean_error_rate < 0.05
assert MAPE < 0.05
return trace_, time_elapsed, test_model, simulation
assert MAPE < 0.3
return trace_, time_elapsed, test_model, simulation, kwargs, y_trace
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def rotate(l, n):
return np.array(l[n:] + l[:n])

week_effect = np.sort(np.random.gamma(shape=1, scale=1, size=7))
day_effect = np.sort(np.random.gamma(shape=1, scale=0.5, size=24))
day_effect = np.sort(np.random.gamma(shape=1, scale=1, size=24))
day_effect = rotate(day_effect, 2)
week_effect = rotate(week_effect, 1)

Expand Down

0 comments on commit decf50c

Please sign in to comment.