Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend multiple data drift univariate methods to multivariate #353

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Extend BhattacharyyaDistance to multivariate
  • Loading branch information
jaime-cespedes-sisniega committed Dec 1, 2024
commit 91c35c53897b048af501176acd07fa5ae7b944b4
43 changes: 32 additions & 11 deletions frouros/detectors/data_drift/batch/distance_based/base.py
Original file line number Diff line number Diff line change
@@ -120,6 +120,7 @@ class BaseDistanceBasedBins(BaseDistanceBased):

def __init__(
self,
statistical_type: BaseStatisticalType,
statistical_method: Callable, # type: ignore
statistical_kwargs: dict[str, Any],
callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
@@ -137,9 +138,12 @@ def __init__(
:type num_bins: int
"""
super().__init__(
statistical_type=UnivariateData(),
statistical_type=statistical_type,
statistical_method=statistical_method,
statistical_kwargs={**statistical_kwargs, "num_bins": num_bins},
statistical_kwargs={
**statistical_kwargs,
"num_bins": num_bins,
},
callbacks=callbacks,
)
self.num_bins = num_bins
@@ -171,23 +175,40 @@ def _distance_measure(
X: np.ndarray, # noqa: N803
**kwargs: Any,
) -> DistanceResult:
distance_bins = self._distance_measure_bins(X_ref=X_ref, X=X)
distance = DistanceResult(distance=distance_bins)
distance_bins = self._distance_measure_bins(
X_ref=X_ref,
X=X,
)
distance = DistanceResult(
distance=distance_bins,
)
return distance

@staticmethod
def _calculate_bins_values(
X_ref: np.ndarray, # noqa: N803
X: np.ndarray,
num_bins: int = 10,
) -> np.ndarray:
bins = np.histogram(np.hstack((X_ref, X)), bins=num_bins)[ # get the bin edges
1
) -> Tuple[np.ndarray, np.ndarray]:
# Add a new axis if X_ref and X are 1D
if X_ref.ndim == 1:
X_ref = X_ref[:, np.newaxis]
X = X[:, np.newaxis]

min_edge = np.min(np.vstack((X_ref, X)), axis=0)
max_edge = np.max(np.vstack((X_ref, X)), axis=0)
bins = [
np.linspace(min_edge[i], max_edge[i], num_bins + 1)
for i in range(X_ref.shape[1])
]
X_ref_percents = ( # noqa: N806
np.histogram(a=X_ref, bins=bins)[0] / X_ref.shape[0]
) # noqa: N806
X_percents = np.histogram(a=X, bins=bins)[0] / X.shape[0] # noqa: N806

X_ref_hist, _ = np.histogramdd(X_ref, bins=bins)
X_hist, _ = np.histogramdd(X, bins=bins)

# Normalize histograms
X_ref_percents = X_ref_hist / X_ref.shape[0]
X_percents = X_hist / X.shape[0]

return X_ref_percents, X_percents

@abc.abstractmethod
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
import numpy as np

from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.detectors.data_drift.base import MultivariateData
from frouros.detectors.data_drift.batch.distance_based.base import (
BaseDistanceBasedBins,
)
@@ -13,7 +14,8 @@
class BhattacharyyaDistance(BaseDistanceBasedBins):
"""Bhattacharyya distance [bhattacharyya1946measure]_ detector.

:param num_bins: number of bins in which to divide probabilities, defaults to 10
:param num_bins: number of bins per dimension in which to
divide probabilities, defaults to 10
:type num_bins: int
:param callbacks: callbacks, defaults to None
:type callbacks: Optional[Union[BaseCallback, list[Callback]]]
@@ -29,12 +31,12 @@ class BhattacharyyaDistance(BaseDistanceBasedBins):
>>> from frouros.detectors.data_drift import BhattacharyyaDistance
>>> import numpy as np
>>> np.random.seed(seed=31)
>>> X = np.random.normal(loc=0, scale=1, size=100)
>>> Y = np.random.normal(loc=1, scale=1, size=100)
>>> detector = BhattacharyyaDistance(num_bins=20)
>>> X = np.random.multivariate_normal(mean=[1, 1], cov=[[2, 0], [0, 2]], size=100)
>>> Y = np.random.multivariate_normal(mean=[0, 0], cov=[[2, 1], [1, 2]], size=100)
>>> detector = BhattacharyyaDistance(num_bins=10)
>>> _ = detector.fit(X=X)
>>> detector.compare(X=Y)
DistanceResult(distance=0.2182101059622703)
DistanceResult(distance=0.3413868461814531)
"""

def __init__( # noqa: D107
@@ -43,6 +45,7 @@ def __init__( # noqa: D107
callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
) -> None:
super().__init__(
statistical_type=MultivariateData(),
statistical_method=self._bhattacharyya,
statistical_kwargs={
"num_bins": num_bins,
@@ -56,7 +59,11 @@ def _distance_measure_bins(
X_ref: np.ndarray, # noqa: N803
X: np.ndarray, # noqa: N803
) -> float:
bhattacharyya = self._bhattacharyya(X=X_ref, Y=X, num_bins=self.num_bins)
bhattacharyya = self._bhattacharyya(
X=X_ref,
Y=X,
num_bins=self.num_bins,
)
return bhattacharyya

@staticmethod
@@ -70,7 +77,23 @@ def _bhattacharyya(
X_percents,
Y_percents,
) = BaseDistanceBasedBins._calculate_bins_values(
X_ref=X, X=Y, num_bins=num_bins
X_ref=X,
X=Y,
num_bins=num_bins,
)
bhattacharyya = 1 - np.sum(np.sqrt(X_percents * Y_percents))

# Add small epsilon to avoid log(0)
epsilon = np.finfo(float).eps
X_percents = X_percents + epsilon
Y_percents = Y_percents + epsilon

# Compute Bhattacharyya coefficient
bc = np.sum(np.sqrt(X_percents * Y_percents))
# Clip between [0,1] to avoid numerical errors
bc = np.clip(bc, a_min=0, a_max=1)

# Compute Bhattacharyya distance
# Use absolute value to avoid negative zero values
bhattacharyya = np.abs(-np.log(bc))

return bhattacharyya
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
import numpy as np

from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.detectors.data_drift.base import UnivariateData
from frouros.detectors.data_drift.batch.distance_based.base import (
BaseDistanceBasedBins,
)
@@ -45,6 +46,7 @@ def __init__( # noqa: D107
) -> None:
sqrt_div = np.sqrt(2)
super().__init__(
statistical_type=UnivariateData(),
statistical_method=self._hellinger,
statistical_kwargs={
"num_bins": num_bins,
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
import numpy as np

from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.detectors.data_drift.base import UnivariateData
from frouros.detectors.data_drift.batch.distance_based.base import (
BaseDistanceBasedBins,
)
@@ -43,6 +44,7 @@ def __init__( # noqa: D107
callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
) -> None:
super().__init__(
statistical_type=UnivariateData(),
statistical_method=self._hi_normalized_complement,
statistical_kwargs={
"num_bins": num_bins,
2 changes: 2 additions & 0 deletions frouros/detectors/data_drift/batch/distance_based/psi.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
import numpy as np

from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.detectors.data_drift.base import UnivariateData
from frouros.detectors.data_drift.batch.distance_based.base import (
BaseDistanceBasedBins,
DistanceResult,
@@ -45,6 +46,7 @@ def __init__( # noqa: D107
callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
) -> None:
super().__init__(
statistical_type=UnivariateData(),
statistical_method=self._psi,
statistical_kwargs={
"num_bins": num_bins,
2 changes: 1 addition & 1 deletion frouros/tests/integration/test_callback.py
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@
@pytest.mark.parametrize(
"detector_class, expected_distance, expected_p_value",
[
(BhattacharyyaDistance, 0.55516059, 0.0),
(BhattacharyyaDistance, 0.81004188, 0.0),
(EMD, 3.85346006, 0.0),
(EnergyDistance, 2.11059982, 0.0),
(HellingerDistance, 0.74509099, 0.0),
34 changes: 23 additions & 11 deletions frouros/tests/integration/test_data_drift.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""Test data drift detectors."""

from typing import Any, Tuple, Union
from typing import (
Any,
Tuple,
Union,
)

import numpy as np
import pytest
@@ -26,12 +30,8 @@
WelchTTest,
)
from frouros.detectors.data_drift.batch.base import BaseDataDriftBatch
from frouros.detectors.data_drift.streaming import (
MMD as MMDStreaming,
)
from frouros.detectors.data_drift.streaming import ( # noqa: N811
IncrementalKSTest,
)
from frouros.detectors.data_drift.streaming import MMD as MMDStreaming
from frouros.detectors.data_drift.streaming import IncrementalKSTest


@pytest.mark.parametrize(
@@ -102,7 +102,7 @@ def test_batch_distance_based_univariate(
[
(PSI(), 461.20379435),
(HellingerDistance(), 0.74509099),
(BhattacharyyaDistance(), 0.55516059),
(BhattacharyyaDistance(), 0.810041883),
],
)
def test_batch_distance_bins_based_univariate_different_distribution(
@@ -133,7 +133,7 @@ def test_batch_distance_bins_based_univariate_different_distribution(
[
(PSI(), 0.01840072),
(HellingerDistance(), 0.04792538),
(BhattacharyyaDistance(), 0.00229684),
(BhattacharyyaDistance(), 0.00229948),
],
)
def test_batch_distance_bins_based_univariate_same_distribution(
@@ -214,7 +214,13 @@ def test_batch_statistical_univariate(
assert np.isclose(p_value, expected_p_value)


@pytest.mark.parametrize("detector, expected_distance", [(MMD(), 0.10163633)])
@pytest.mark.parametrize(
"detector, expected_distance",
[
(BhattacharyyaDistance(), 0.39327743),
(MMD(), 0.10163633),
],
)
def test_batch_distance_based_multivariate_different_distribution(
X_ref_multivariate: np.ndarray, # noqa: N803
X_test_multivariate: np.ndarray, # noqa: N803
@@ -238,7 +244,13 @@ def test_batch_distance_based_multivariate_different_distribution(
assert np.isclose(statistic, expected_distance)


@pytest.mark.parametrize("detector, expected_distance", [(MMD(), 0.01570397)])
@pytest.mark.parametrize(
"detector, expected_distance",
[
(BhattacharyyaDistance(), 0.39772951),
(MMD(), 0.01570397),
],
)
def test_batch_distance_based_multivariate_same_distribution(
multivariate_distribution_p: Tuple[np.ndarray, np.ndarray],
detector: BaseDataDriftBatch,
Loading