-
-
Notifications
You must be signed in to change notification settings - Fork 555
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
81401ce
commit dfecc51
Showing
5 changed files
with
282 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from __future__ import annotations | ||
|
||
from .efficient_rollingprauc import EfficientRollingPRAUC | ||
|
||
__all__ = ["EfficientRollingPRAUC"] |
150 changes: 150 additions & 0 deletions
150
river/metrics/efficient_rollingprauc/cpp/RollingPRAUC.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
#include "RollingPRAUC.hpp" | ||
|
||
#include <limits> | ||
#include <stdlib.h> | ||
|
||
namespace rollingprauc { | ||
|
||
// Default configuration: label 1 is the positive class and the window holds
// at most 1000 instances. Delegates to the two-argument constructor.
RollingPRAUC::RollingPRAUC(): RollingPRAUC(1, 1000) {
}
|
||
// Builds an empty window that treats `positiveLabel` as the positive class
// and keeps at most `windowSize` instances.
RollingPRAUC::RollingPRAUC(int positiveLabel, long unsigned windowSize)
    : positiveLabel(positiveLabel),
      windowSize(windowSize),
      positives(0) {
}
|
||
// Adds a new (label, score) instance to the rolling window, evicting the
// oldest instance first when the window is already full.
void RollingPRAUC::update(int label, double score) {
    // A zero-length window can never hold an instance. Without this guard the
    // empty window would compare equal to windowSize and removeLast() would
    // call front()/pop_front() on an empty deque, which is undefined behaviour.
    if (this->windowSize == 0)
        return;

    // Make room for the incoming instance
    if (this->window.size() >= this->windowSize)
        this->removeLast();

    this->insert(label, score);
}
|
||
void RollingPRAUC::revert(int label, double score) { | ||
int normalizedLabel = 0; | ||
if (label == this->positiveLabel) | ||
normalizedLabel = 1; | ||
|
||
std::deque<std::tuple<double, int>>::const_iterator it{this->window.cbegin()}; | ||
for (; it != this->window.cend(); ++it) | ||
if (std::get<0>(*it) == score && std::get<1>(*it) == normalizedLabel) | ||
break; | ||
|
||
if (it == this->window.cend()) | ||
return; | ||
|
||
if (normalizedLabel) | ||
this->positives--; | ||
|
||
this->window.erase(it); | ||
|
||
std::multiset<std::tuple<double, int>>::const_iterator itr{ | ||
this->orderedWindow.find(std::make_tuple(score, label)) | ||
}; | ||
this->orderedWindow.erase(itr); | ||
|
||
return; | ||
} | ||
|
||
// Computes the area under the precision-recall curve for the instances
// currently held in the window.
//
// Returns 0 when the window holds only one class (or is empty), since the
// metric would otherwise require a division by zero.
double RollingPRAUC::get() const {
    // Counters use std::size_t, the type of the member they are read from;
    // brace-initializing an `unsigned long` from it is a narrowing conversion
    // on platforms where `unsigned long` is narrower than std::size_t.
    const std::size_t total{this->window.size()};
    const std::size_t negatives{total - this->positives};

    if (this->positives == 0 || negatives == 0)
        return 0;

    // Start with every instance predicted positive
    std::size_t fp{negatives};
    std::size_t tp{this->positives};
    std::size_t tpPrev{tp};

    double auc{0};
    double scorePrev{std::numeric_limits<double>::max()};

    double prec{tp / static_cast<double>(tp + fp)};
    double precPrev{prec};

    // orderedWindow is iterated in ascending score order; each visited
    // instance stops being counted as a predicted positive.
    for (const std::tuple<double, int>& instance : this->orderedWindow) {
        const double score{std::get<0>(instance)};
        const int label{std::get<1>(instance)};

        // A new point of the curve is produced only when the score changes,
        // so instances with tied scores are handled together.
        if (score != scorePrev) {
            prec = tp / static_cast<double>(tp + fp);

            // Keep the precision from decreasing along the walk
            if (precPrev > prec)
                prec = precPrev;

            auc += this->trapzArea(tp, tpPrev, prec, precPrev);

            scorePrev = score;
            tpPrev = tp;
            precPrev = prec;
        }

        if (label)
            tp--;
        else
            fp--;
    }

    // Close the curve with a final segment that ends at precision 1.0
    auc += this->trapzArea(tp, tpPrev, 1.0, precPrev);

    return auc / this->positives; // Scale the x axis
}
|
||
// Stores a new instance in both containers. The label is kept in normalized
// form: 1 when it equals the positive label, 0 otherwise.
void RollingPRAUC::insert(int label, double score) {
    const bool isPositive = (label == this->positiveLabel);

    if (isPositive)
        ++this->positives;

    const int normalized = isPositive ? 1 : 0;

    this->window.emplace_back(score, normalized);
    this->orderedWindow.emplace(score, normalized);
}
|
||
// Drops the oldest instance (the front of the window) from both containers
// and keeps the positives counter in sync.
void RollingPRAUC::removeLast() {
    const std::tuple<double, int> oldest = this->window.front();
    this->window.pop_front();

    if (std::get<1>(oldest) != 0)
        --this->positives;

    // Erase by iterator: erasing by value would remove every equivalent
    // instance from the multiset instead of a single one.
    auto match = this->orderedWindow.find(oldest);
    this->orderedWindow.erase(match);
}
|
||
// Returns the normalized labels (0 or 1) of the instances in the window,
// in window order (oldest first).
std::vector<int> RollingPRAUC::getTrueLabels() const {
    std::vector<int> trueLabels;
    // The final size is known up front, so allocate once
    trueLabels.reserve(this->window.size());

    for (const std::tuple<double, int>& instance : this->window)
        trueLabels.push_back(std::get<1>(instance));

    return trueLabels;
}
|
||
// Returns the scores of the instances in the window, in window order
// (oldest first).
std::vector<double> RollingPRAUC::getScores() const {
    std::vector<double> scores;
    // The final size is known up front, so allocate once
    scores.reserve(this->window.size());

    for (const std::tuple<double, int>& instance : this->window)
        scores.push_back(std::get<0>(instance));

    return scores;
}
|
||
// Area of the trapezoid whose parallel sides have lengths y1 and y2 and
// whose width is the distance between x1 and x2.
double RollingPRAUC::trapzArea(double x1, double x2, double y1, double y2) const {
    // The width is computed by hand instead of calling the unqualified abs():
    // with only <stdlib.h> included, that call may resolve to int abs(int) on
    // some toolchains and silently truncate the double argument.
    const double width = (x1 > x2) ? (x1 - x2) : (x2 - x1);

    return width * (y1 + y2) / 2;
}
|
||
} // namespace rollingprauc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#ifndef ROLLINGPRAUC_HPP | ||
#define ROLLINGPRAUC_HPP | ||
|
||
#include <deque> | ||
#include <set> | ||
#include <tuple> | ||
#include <vector> | ||
|
||
namespace rollingprauc { | ||
|
||
class RollingPRAUC {
    public:
        // Default construction
        RollingPRAUC();

        // Construction with an explicit positive label and window capacity
        RollingPRAUC(int positiveLabel, unsigned long windowSize);

        virtual ~RollingPRAUC() = default;

        // Add an instance to the window; relies on removeLast() and insert()
        virtual void update(int label, double score);

        // Undo an update: erase a single stored instance matching the params
        virtual void revert(int label, double score);

        // Compute the PRAUC over the current window contents
        virtual double get() const;

        // Labels of the stored instances (y_true), copied into a vector
        virtual std::vector<int> getTrueLabels() const;

        // Scores of the stored instances (y_score), copied into a vector
        virtual std::vector<double> getScores() const;

    private:
        // Store a new instance built from the params
        virtual void insert(int label, double score);

        // Discard the oldest stored instance
        virtual void removeLast();

        // Area of a trapezoid spanning x1..x2 with side heights y1 and y2
        double trapzArea(double x1, double x2, double y1, double y2) const;

        // Label value that is treated as the positive class
        int positiveLabel;

        // Window capacity and number of positive instances currently stored
        std::size_t windowSize;
        std::size_t positives;

        // Instances in arrival order, which preserves the temporal aspect
        // of the stream. A deque is used so that revert() is possible.
        std::deque<std::tuple<double, int>> window;

        // The same instances kept sorted in a multiset (a tree-based
        // container), as needed to walk them in order.
        std::multiset<std::tuple<double, int>> orderedWindow;
};
|
||
} // namespace rollingprauc | ||
|
||
#endif |
13 changes: 13 additions & 0 deletions
13
river/metrics/efficient_rollingprauc/efficient_rollingprauc.pxd
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from libcpp.vector cimport vector | ||
|
||
# Including the implementation file makes Cython emit an #include for it, so
# the C++ source is compiled as part of the generated extension module.
cdef extern from "cpp/RollingPRAUC.cpp":
    pass

# Declarations of the C++ class exposed to the Cython wrapper.
cdef extern from "cpp/RollingPRAUC.hpp" namespace "rollingprauc":
    cdef cppclass RollingPRAUC:
        # windowSize is `unsigned long` to match the C++ constructor; declaring
        # it as `int` would truncate large window sizes passed by the wrapper.
        RollingPRAUC(int positiveLabel, unsigned long windowSize) except +
        # Methods that allocate are marked `except +` so that a C++ exception
        # (e.g. std::bad_alloc) is translated into a Python exception.
        void update(int label, double score) except +
        void revert(int label, double score)
        double get()
        vector[int] getTrueLabels() except +
        vector[double] getScores() except +
55 changes: 55 additions & 0 deletions
55
river/metrics/efficient_rollingprauc/efficient_rollingprauc.pyx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# distutils: language = c++ | ||
# distutils: extra_compile_args = "-std=c++11" | ||
|
||
import cython | ||
|
||
from .efficient_rollingprauc cimport RollingPRAUC as CppRollingPRAUC | ||
|
||
cdef class EfficientRollingPRAUC:
    """
    Wrapper that owns a C++ ``RollingPRAUC`` object and forwards
    ``update``, ``revert`` and ``get`` calls to it. Pickling is supported
    by dumping the labels and scores held in the window.
    """
    cdef cython.int positiveLabel
    cdef cython.ulong windowSize
    cdef CppRollingPRAUC* rollingprauc

    def __cinit__(self, cython.int positiveLabel, cython.ulong windowSize):
        # Keep the constructor arguments: they are needed again for pickling
        self.positiveLabel = positiveLabel
        self.windowSize = windowSize
        self.rollingprauc = new CppRollingPRAUC(positiveLabel, windowSize)

    def __dealloc__(self):
        # Release the C++ object, if one was allocated
        if self.rollingprauc != NULL:
            del self.rollingprauc

    def update(self, label, score):
        self.rollingprauc.update(label, score)

    def revert(self, label, score):
        self.rollingprauc.revert(label, score)

    def get(self):
        return self.rollingprauc.get()

    def __getnewargs_ex__(self):
        # Used by pickle to obtain the arguments that are passed to __new__
        args = (self.positiveLabel, self.windowSize)
        kwargs = {}
        return args, kwargs

    def __getstate__(self):
        """
        State dumped on pickling: the true labels and the scores of the
        instances currently in the window
        """
        return (
            self.rollingprauc.getTrueLabels(),
            self.rollingprauc.getScores(),
        )

    def __setstate__(self, state):
        """
        Rebuild the window on unpickling by replaying the true labels
        and scores carried by the state parameter
        """

        # __getstate__ hands back normalized labels (0 or 1)
        labels, scores = state

        for label, score in zip(labels, scores):
            if label:
                # Positive instance: use the label given to the constructor
                restored = self.positiveLabel
            else:
                # Negative instance: any label other than the positive one
                restored = int(not self.positiveLabel)
            self.update(restored, score)