Skip to content

Commit

Permalink
ATWIN online-ml#3
Browse files Browse the repository at this point in the history
  • Loading branch information
DuckManGO committed Sep 17, 2023
1 parent 31b90ff commit f5fe253
Show file tree
Hide file tree
Showing 9 changed files with 2,045 additions and 41 deletions.
5 changes: 2 additions & 3 deletions river/drift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
"""
from __future__ import annotations

from . import binary, datasets
from . import binary, datasets, atwin
from .adwin import ADWIN
from .atwin import ATWIN
from .dummy import DummyDriftDetector
from .kswin import KSWIN
from .page_hinkley import PageHinkley
Expand All @@ -19,8 +18,8 @@
__all__ = [
"binary",
"datasets",
"atwin",
"ADWIN",
"ATWIN",
"DriftRetrainingClassifier",
"DummyDriftDetector",
"KSWIN",
Expand Down
341 changes: 341 additions & 0 deletions river/drift/adwin_nc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,341 @@
from math import fabs, log, pow, sqrt

import numpy as np

from collections import deque
from typing import Deque


class AdaptiveWindowing:
    """Pure-Python helper that maintains the adaptive window used by ADWIN.

    The window is summarised by rows of buckets kept in ``bucket_deque``.
    Row ``i`` holds buckets that each summarise ``2 ** i`` observations; new
    observations enter row 0 and the oldest bucket is removed from the last
    row. Alongside the buckets, the running ``total``, the accumulated
    (un-normalised) ``variance`` and the ``width`` of the window are tracked.

    Parameters
    ----------
    delta
        Confidence value.
    clock
        How often ADWIN should check for change. 1 means every new data point, default is 32.
        Higher values speed up processing, but may also lead to increased delay in change
        detection.
    max_buckets
        The maximum number of buckets of each size that ADWIN should keep before merging buckets
        (default is 5).
    min_window_length
        The minimum length of each subwindow (default is 5). Lower values may decrease delay in
        change detection but may also lead to more false positives.
    grace_period
        ADWIN does not perform any change detection until at least this many data points have
        arrived (default is 10).
    """

    def __init__(self, delta=.002, clock=32, max_buckets=5, min_window_length=5, grace_period=10):
        self.delta = delta
        # Row 0 (newest observations) always exists; further rows are added on demand.
        self.bucket_deque: Deque['Bucket'] = deque([Bucket(max_size=max_buckets)])
        self.total = 0.  # sum of the values currently in the window
        self.variance = 0.  # accumulated variance; divide by width to normalise
        self.width = 0.  # number of observations currently in the window
        self.n_buckets = 0
        self.grace_period = grace_period
        self.tick = 0  # number of calls to _detect_change so far
        self.total_width = 0  # sum of the window width over all updates
        self.n_detections = 0
        self.clock = clock
        self.max_n_buckets = 0  # largest value n_buckets has reached
        self.min_window_length = min_window_length
        self.max_buckets = max_buckets

    def get_n_detections(self):
        """Return the number of updates on which a change was detected."""
        return self.n_detections

    def get_width(self):
        """Return the current window width."""
        return self.width

    def get_total(self):
        """Return the sum of the values in the window."""
        return self.total

    def get_variance(self):
        """Return the accumulated (un-normalised) variance of the window."""
        return self.variance

    @property
    def variance_in_window(self):
        """Accumulated variance divided by the window width.

        Returns 0.0 while the window is empty (e.g. right after construction)
        instead of raising ``ZeroDivisionError``.
        """
        if not self.width > 0:
            return 0.0
        return self.variance / self.width

    def update(self, value):
        """Update the change detector with a single data point.

        Apart from adding the element value to the window, by inserting it in
        the correct bucket, it will also update the relevant statistics, in
        this case the total sum of all values, the window width and the total
        variance.

        Parameters
        ----------
        value
            Input value

        Returns
        -------
        bool
            If True then a change is detected.
        """
        return self._update(value)

    def _update(self, value):
        """Insert ``value`` into the window, then run the change check."""
        # Increment window with one element
        self._insert_element(value, 0.0)

        return self._detect_change()

    def _insert_element(self, value, variance):
        """Add one observation to row 0 and update the window statistics."""
        bucket = self.bucket_deque[0]
        bucket.insert_data(value, variance)
        self.n_buckets += 1

        if self.n_buckets > self.max_n_buckets:
            self.max_n_buckets = self.n_buckets

        # Update width, variance and total
        self.width += 1
        incremental_variance = 0.0
        if self.width > 1.0:
            # Deviation of the new value from the mean of the previous window.
            deviation = value - self.total / (self.width - 1.0)
            incremental_variance = (
                (self.width - 1.0) * deviation * deviation / self.width
            )
        self.variance += incremental_variance
        self.total += value

        self._compress_buckets()

    @staticmethod
    def _calculate_bucket_size(row: int):
        """Return the number of observations summarised by a bucket of ``row``."""
        return pow(2, row)

    def _delete_element(self):
        """Drop the oldest bucket (first slot of the last row) from the window.

        Returns
        -------
        float
            The number of observations that were removed.
        """
        bucket = self.bucket_deque[-1]
        n = self._calculate_bucket_size(len(self.bucket_deque) - 1)  # length of bucket
        u = bucket.get_total_at(0)  # total of bucket
        mu = u / n  # mean of bucket
        v = bucket.get_variance_at(0)  # variance of bucket

        # Update width, total and variance
        self.width -= n
        self.total -= u
        # NOTE(review): assumes the window is still non-empty after the removal.
        mu_window = self.total / self.width  # mean of the window
        incremental_variance = (
            v + n * self.width * (mu - mu_window) * (mu - mu_window)
            / (n + self.width)
        )
        self.variance -= incremental_variance

        bucket.remove()
        self.n_buckets -= 1

        # Discard the row once its last bucket is gone.
        if bucket.current_idx == 0:
            self.bucket_deque.pop()

        return n

    def _compress_buckets(self):
        """Merge the two oldest buckets of every overfull row into the next row."""
        bucket = self.bucket_deque[0]
        idx = 0
        while bucket is not None:
            k = bucket.current_idx
            # Merge buckets if there are more than max_buckets
            if k == self.max_buckets + 1:
                try:
                    next_bucket = self.bucket_deque[idx + 1]
                except IndexError:
                    self.bucket_deque.append(Bucket(max_size=self.max_buckets))
                    next_bucket = self.bucket_deque[-1]
                # Both merged buckets live in row idx, hence share one size.
                n1 = n2 = self._calculate_bucket_size(idx)
                total1 = bucket.get_total_at(0)
                total2 = bucket.get_total_at(1)
                mu1 = total1 / n1  # mean of bucket 1
                mu2 = total2 / n2  # mean of bucket 2

                # Combine total and variance of adjacent buckets
                total12 = total1 + total2
                temp = n1 * n2 * (mu1 - mu2) * (mu1 - mu2) / (n1 + n2)
                v12 = bucket.get_variance_at(0) + bucket.get_variance_at(1) + temp
                next_bucket.insert_data(total12, v12)
                # NOTE(review): a merge turns two buckets into one, yet the
                # counter is incremented here -- confirm this is intended.
                self.n_buckets += 1
                bucket.compress(2)

                # Stop as soon as the receiving row is not overfull itself.
                if next_bucket.current_idx <= self.max_buckets:
                    break
            else:
                break

            try:
                bucket = self.bucket_deque[idx + 1]
            except IndexError:
                bucket = None
            idx += 1

    def _detect_change(self):
        """Detect concept change.

        This function is responsible for analysing different cutting points in
        the sliding window, to verify if there is a significant change. Cuts
        are only evaluated every ``clock`` calls and once the window is wider
        than ``grace_period``. While a cut is significant, the oldest bucket is
        dropped and the scan restarts on the shrunken window.

        Returns
        -------
        bool
            If True then a change is detected.

        Notes
        -----
        Variance bookkeeping of the window is based on:
        Babcock, B., Datar, M., Motwani, R., & O'Callaghan, L. (2003).
        Maintaining Variance and k-Medians over Data Stream Windows.
        Proceedings of the ACM SIGACT-SIGMOD-SIGART
        Symposium on Principles of Database Systems, 22, 234-243.
        https://doi.org/10.1145/773153.773176
        """
        change_detected = False
        exit_flag = False
        self.tick += 1

        # Reduce window
        if (self.tick % self.clock == 0) and (self.width > self.grace_period):
            reduce_width = True
            while reduce_width:
                reduce_width = False
                exit_flag = False
                n0 = 0.0  # length of window 0 (oldest side of the cut)
                n1 = self.width  # length of window 1 (newest side of the cut)
                u0 = 0.0  # total of window 0
                u1 = self.total  # total of window 1

                # Evaluate each window cut (W_0, W_1), oldest row first
                for idx in range(len(self.bucket_deque) - 1, -1, -1):
                    if exit_flag:
                        break
                    bucket = self.bucket_deque[idx]
                    # Every bucket of this row summarises the same number of items.
                    bucket_size = self._calculate_bucket_size(idx)

                    for k in range(bucket.current_idx - 1):
                        bucket_total = bucket.get_total_at(k)

                        # Move bucket k from window 1 to window 0
                        n0 += bucket_size
                        n1 -= bucket_size
                        u0 += bucket_total
                        u1 -= bucket_total

                        # NOTE(review): with range(current_idx - 1) above, k never
                        # equals current_idx - 1, so this guard cannot fire as
                        # written -- confirm the intended loop bound.
                        if (idx == 0) and (k == bucket.current_idx - 1):
                            exit_flag = True  # We are done
                            break

                        # Check if delta_mean < epsilon_cut holds
                        # Note: Must re-calculate means per updated values
                        delta_mean = (u0 / n0) - (u1 / n1)
                        if (
                            n1 >= self.min_window_length
                            and n0 >= self.min_window_length
                            and self._evaluate_cut(n0, n1, delta_mean, self.delta)
                        ):
                            # Change detected
                            reduce_width = True
                            change_detected = True
                            if self.width > 0:
                                # Reduce the width of the window
                                n0 -= self._delete_element()
                                exit_flag = True  # We are done
                                break

        self.total_width += self.width
        if change_detected:
            self.n_detections += 1

        return change_detected

    def _evaluate_cut(self, n0, n1,
                      delta_mean, delta):
        """Return True when ``|delta_mean|`` exceeds the cut threshold epsilon.

        Parameters
        ----------
        n0, n1
            Lengths of the two sub-windows defined by the cut.
        delta_mean
            Difference between the means of the two sub-windows.
        delta
            Confidence value used to derive the threshold.
        """
        delta_prime = log(2 * log(self.width) / delta)
        # Use reciprocal of m to avoid extra divisions when calculating epsilon
        m_recip = ((1.0 / (n0 - self.min_window_length + 1))
                   + (1.0 / (n1 - self.min_window_length + 1)))
        epsilon = (sqrt(2 * m_recip * self.variance_in_window * delta_prime)
                   + 2 / 3 * delta_prime * m_recip)
        return fabs(delta_mean) > epsilon


class Bucket:
    """A bucket class to keep statistics.

    A bucket stores the summary structure for a contiguous set of data elements.
    In this implementation fixed-size arrays (``max_size + 1`` slots) are used
    for efficiency. The index of the "current" element is used to simulate the
    dynamic size of the bucket: slots ``0 .. current_idx - 1`` are in use.

    Parameters
    ----------
    max_size
        Number of summaries the row may hold; one extra slot is allocated on
        top of it.
    """

    def __init__(self, max_size):
        self.max_size = max_size

        self.current_idx = 0  # index of the next free slot
        self.total_array = np.zeros(self.max_size + 1, dtype=float)
        self.variance_array = np.zeros(self.max_size + 1, dtype=float)

    def clear_at(self, index):
        """Reset the total and variance stored at ``index`` to zero."""
        self.set_total_at(0.0, index)
        self.set_variance_at(0.0, index)

    def insert_data(self, value, variance):
        """Store ``value`` and ``variance`` in the next free slot."""
        self.set_total_at(value, self.current_idx)
        self.set_variance_at(variance, self.current_idx)
        self.current_idx += 1

    def remove(self):
        """Drop the first (oldest) slot."""
        self.compress(1)

    def compress(self, n_elements):
        """Drop the first ``n_elements`` slots, shifting the rest to the left.

        The freed slots at the end are zeroed and ``current_idx`` is decreased
        by ``n_elements``.

        Raises
        ------
        ValueError
            If ``n_elements`` is negative or larger than the array length.
            Nothing is modified in that case.
        """
        window_len = len(self.total_array)
        if not 0 <= n_elements <= window_len:
            raise ValueError(
                f"n_elements must be between 0 and {window_len}, got {n_elements}"
            )
        kept = window_len - n_elements
        for stats in (self.total_array, self.variance_array):
            # Copy the source slice first: source and destination overlap.
            stats[:kept] = stats[n_elements:].copy()
            stats[kept:] = 0.0

        self.current_idx -= n_elements

    def get_total_at(self, index):
        """Return the total stored at ``index``."""
        return self.total_array[index]

    def get_variance_at(self, index):
        """Return the variance stored at ``index``."""
        return self.variance_array[index]

    def set_total_at(self, value, index):
        """Store ``value`` as the total at ``index``."""
        self.total_array[index] = value

    def set_variance_at(self, value, index):
        """Store ``value`` as the variance at ``index``."""
        self.variance_array[index] = value
7 changes: 7 additions & 0 deletions river/drift/atwin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Test ATWIN Drift Detectors."""
from __future__ import annotations

from .atwin import ATWIN
from .atwin_2 import ATWIN2

__all__ = ["ATWIN", "ATWIN2"]
Loading

0 comments on commit f5fe253

Please sign in to comment.