Skip to content

Commit

Permalink
Add Efficient RollingPRAUC
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlpgomes committed May 10, 2024
1 parent 81401ce commit dfecc51
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 0 deletions.
5 changes: 5 additions & 0 deletions river/metrics/efficient_rollingprauc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from .efficient_rollingprauc import EfficientRollingPRAUC

__all__ = ["EfficientRollingPRAUC"]
150 changes: 150 additions & 0 deletions river/metrics/efficient_rollingprauc/cpp/RollingPRAUC.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#include "RollingPRAUC.hpp"

#include <limits>
#include <stdlib.h>

namespace rollingprauc {

RollingPRAUC::RollingPRAUC(): positiveLabel{1}, windowSize{1000}, positives{0} {
}

RollingPRAUC::RollingPRAUC(int positiveLabel, long unsigned windowSize):
positiveLabel{positiveLabel}, windowSize{windowSize}, positives{0} {
}

void RollingPRAUC::update(int label, double score) {
if (this->window.size() == this->windowSize)
this->removeLast();

this->insert(label, score);

return;
}

void RollingPRAUC::revert(int label, double score) {
int normalizedLabel = 0;
if (label == this->positiveLabel)
normalizedLabel = 1;

std::deque<std::tuple<double, int>>::const_iterator it{this->window.cbegin()};
for (; it != this->window.cend(); ++it)
if (std::get<0>(*it) == score && std::get<1>(*it) == normalizedLabel)
break;

if (it == this->window.cend())
return;

if (normalizedLabel)
this->positives--;

this->window.erase(it);

std::multiset<std::tuple<double, int>>::const_iterator itr{
this->orderedWindow.find(std::make_tuple(score, label))
};
this->orderedWindow.erase(itr);

return;
}

double RollingPRAUC::get() const {
unsigned long windowSize{this->window.size()};

// If there is only one class in the window, it will lead to a
// division by zero. So, zero is returned.
if (!this->positives || !(windowSize - this->positives))
return 0;

unsigned long fp{windowSize - this->positives};
unsigned long tp{this->positives}, tpPrev{tp};

double auc{0}, scorePrev{std::numeric_limits<double>::max()};

double prec{tp / (double) (tp + fp)}, precPrev{prec};

std::multiset<std::tuple<double, int>>::const_iterator it{this->orderedWindow.begin()};
double score;
int label;

for (; it != this->orderedWindow.end(); ++it) {
score = std::get<0>(*it);
label = std::get<1>(*it);

if (score != scorePrev) {
prec = tp / (double) (tp + fp);

if (precPrev > prec)
prec = precPrev; // Monotonic. decreasing

auc += this->trapzArea(tp, tpPrev, prec, precPrev);

scorePrev = score;
tpPrev = tp;
precPrev = prec;
}

if (label) tp--;
else fp--;
}

auc += this->trapzArea(tp, tpPrev, 1.0, precPrev);

return auc / this->positives; // Scale the x axis
}

void RollingPRAUC::insert(int label, double score) {
// Normalize label to 0 (negative) or 1 (positive)
int l = 0;
if (label == this->positiveLabel) {
l = 1;
this->positives++;
}

this->window.emplace_back(score, l);
this->orderedWindow.emplace(score, l);

return;
}

void RollingPRAUC::removeLast() {
std::tuple<double, int> last{this->window.front()};

if (std::get<1>(last))
this->positives--;

this->window.pop_front();

// Erase using a iterator to avoid multiple erases with equivalent instances
std::multiset<std::tuple<double, int>>::iterator it{
this->orderedWindow.find(last)
};
this->orderedWindow.erase(it);

return;
}

std::vector<int> RollingPRAUC::getTrueLabels() const {
std::vector<int> trueLabels;

std::deque<std::tuple<double, int>>::const_iterator it{this->window.begin()};
for (; it != this->window.end(); ++it)
trueLabels.push_back(std::get<1>(*it));

return trueLabels;
}

std::vector<double> RollingPRAUC::getScores() const {
std::vector<double> scores;

std::deque<std::tuple<double, int>>::const_iterator it{this->window.begin()};
for (; it != this->window.end(); ++it)
scores.push_back(std::get<0>(*it));

return scores;
}

double RollingPRAUC::trapzArea(double x1, double x2, double y1, double y2) const {
return abs(x1 - x2) * (y1 + y2) / 2;
}

} // namespace rollingprauc
59 changes: 59 additions & 0 deletions river/metrics/efficient_rollingprauc/cpp/RollingPRAUC.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#ifndef ROLLINGPRAUC_HPP
#define ROLLINGPRAUC_HPP

#include <deque>
#include <set>
#include <tuple>
#include <vector>

namespace rollingprauc {

class RollingPRAUC {
public:
RollingPRAUC();
RollingPRAUC(const int positiveLabel, const long unsigned windowSize);

virtual ~RollingPRAUC() = default;

// Calls insert() and removeLast() if needed
virtual void update(const int label, const double score);

// Erase the most recent instance with content equal to params
virtual void revert(const int label, const double score);

// Calculates the PRAUC and returns it
virtual double get() const;

// Returns y_true as a vector
virtual std::vector<int> getTrueLabels() const;

// Returns y_score as a vector
virtual std::vector<double> getScores() const;

private:
// Insert instance based on params
virtual void insert(const int label, const double score);

// Remove oldest instance
virtual void removeLast();

// Calculates the trapezoid area
double trapzArea(double x1, double x2, double y1, double y2) const;

int positiveLabel;

std::size_t windowSize;
std::size_t positives;

// window maintains a queue of the instances to store the temporal
// aspect of the stream. Using deque to allow revert()
std::deque<std::tuple<double, int>> window;

// orderedWindow maintains a multiset (implemented as a tree)
// to store the ordered instances
std::multiset<std::tuple<double, int>> orderedWindow;
};

} // namespace rollingprauc

#endif
13 changes: 13 additions & 0 deletions river/metrics/efficient_rollingprauc/efficient_rollingprauc.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from libcpp.vector cimport vector

cdef extern from "cpp/RollingPRAUC.cpp":
pass

cdef extern from "cpp/RollingPRAUC.hpp" namespace "rollingprauc":
cdef cppclass RollingPRAUC:
RollingPRAUC(int positiveLabel, int windowSize) except +
void update(int label, double score)
void revert(int label, double score)
double get()
vector[int] getTrueLabels()
vector[double] getScores()
55 changes: 55 additions & 0 deletions river/metrics/efficient_rollingprauc/efficient_rollingprauc.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# distutils: language = c++
# distutils: extra_compile_args = "-std=c++11"

import cython

from .efficient_rollingprauc cimport RollingPRAUC as CppRollingPRAUC

cdef class EfficientRollingPRAUC:
cdef cython.int positiveLabel
cdef cython.ulong windowSize
cdef CppRollingPRAUC* rollingprauc

def __cinit__(self, cython.int positiveLabel, cython.ulong windowSize):
self.positiveLabel = positiveLabel
self.windowSize = windowSize
self.rollingprauc = new CppRollingPRAUC(positiveLabel, windowSize)

def __dealloc__(self):
if not self.rollingprauc == NULL:
del self.rollingprauc

def update(self, label, score):
self.rollingprauc.update(label, score)

def revert(self, label, score):
self.rollingprauc.revert(label, score)

def get(self):
return self.rollingprauc.get()

def __getnewargs_ex__(self):
# Pickle will use this function to pass the arguments to __new__
return (self.positiveLabel, self.windowSize),{}

def __getstate__(self):
"""
On pickling, the true labels and scores of the instances in the
window will be dumped
"""
return (self.rollingprauc.getTrueLabels(), self.rollingprauc.getScores())

def __setstate__(self, state):
"""
On unpickling, the state parameter will have the true labels
and scores, this function updates the rollingprauc with them
"""

# Labels returned by __getstate__ are normalized (0 or 1)
labels, scores = state

for label, score in zip(labels, scores):
# If label is 1, update with the positive label defined by the constructor
# Else, update with a negative label
l = self.positiveLabel if label else int(not self.positiveLabel)
self.update(l, score)

0 comments on commit dfecc51

Please sign in to comment.