Skip to content

Commit

Permalink
Add Multiclass Ridge classifier (#697)
Browse files Browse the repository at this point in the history
* Add Multiclass Ridge classifier

* Fix test

* Fix least square
  • Loading branch information
ishii-norimi authored Nov 21, 2023
1 parent d8d002b commit d5fac14
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 106 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ for (let i = 0; i < n; i++) {
| task | model |
| ---- | ----- |
| clustering | (Soft / Kernel / Genetic / Weighted) k-means, k-means++, k-medoids, k-medians, x-means, G-means, LBG, ISODATA, Fuzzy c-means, Possibilistic c-means, Agglomerative (complete linkage, single linkage, group average, Ward's, centroid, weighted average, median), DIANA, Monothetic, Mutual kNN, Mean shift, DBSCAN, OPTICS, HDBSCAN, DENCLUE, DBCLASD, CLUES, PAM, CLARA, CLARANS, BIRCH, CURE, ROCK, C2P, PLSA, Latent dirichlet allocation, GMM, VBGMM, Affinity propagation, Spectral clustering, Mountain, (Growing) SOM, GTM, (Growing) Neural gas, Growing cell structures, LVQ, ART, SVC, CAST, CHAMELEON, COLL, CLIQUE, PROCLUS, ORCLUS, FINDIT, NMF, Autoencoder |
| classification | (Fisher's) Linear discriminant, Quadratic discriminant, Mixture discriminant, Least squares, Ridge, (Complement / Negation / Universal-set / Selective) Naive Bayes (gaussian), AODE, (Fuzzy / Weighted) k-nearest neighbor, Radius neighbor, Nearest centroid, ENN, ENaN, NNBCA, ADAMENN, DANN, IKNN, Decision tree, Random forest, Extra trees, GBDT, XGBoost, ALMA, (Aggressive) ROMMA, (Bounded) Online gradient descent, (Budgeted online) Passive aggressive, RLS, (Selective-sampling) Second order perceptron, AROW, NAROW, Confidence weighted, CELLIP, IELLIP, Normal herd, Stoptron, (Kernelized) Pegasos, MIRA, Forgetron, Projectron, Projectron++, Banditron, Ballseptron, (Multiclass) BSGD, ILK, SILK, (Multinomial) Logistic regression, (Multinomial) Probit, SVM, Gaussian process, HMM, CRF, Bayesian Network, LVQ, (Average / Multiclass / Voted / Kernelized / Selective-sampling / Margin / Shifting / Budget / Tighter / Tightest) Perceptron, PAUM, RBP, ADALINE, MADALINE, MLP, LMNN |
| classification | (Fisher's) Linear discriminant, Quadratic discriminant, Mixture discriminant, Least squares, (Multiclass / Kernel) Ridge, (Complement / Negation / Universal-set / Selective) Naive Bayes (gaussian), AODE, (Fuzzy / Weighted) k-nearest neighbor, Radius neighbor, Nearest centroid, ENN, ENaN, NNBCA, ADAMENN, DANN, IKNN, Decision tree, Random forest, Extra trees, GBDT, XGBoost, ALMA, (Aggressive) ROMMA, (Bounded) Online gradient descent, (Budgeted online) Passive aggressive, RLS, (Selective-sampling) Second order perceptron, AROW, NAROW, Confidence weighted, CELLIP, IELLIP, Normal herd, Stoptron, (Kernelized) Pegasos, MIRA, Forgetron, Projectron, Projectron++, Banditron, Ballseptron, (Multiclass) BSGD, ILK, SILK, (Multinomial) Logistic regression, (Multinomial) Probit, SVM, Gaussian process, HMM, CRF, Bayesian Network, LVQ, (Average / Multiclass / Voted / Kernelized / Selective-sampling / Margin / Shifting / Budget / Tighter / Tightest) Perceptron, PAUM, RBP, ADALINE, MADALINE, MLP, LMNN |
| semi-supervised classification | k-nearest neighbor, Radius neighbor, Label propagation, Label spreading, k-means, GMM, S3VM, Ladder network |
| regression | Least squares, Ridge, Lasso, Elastic net, RLS, Bayesian linear, Poisson, Least absolute deviations, Huber, Tukey, Least trimmed squares, Least median squares, Lp norm linear, SMA, Deming, Segmented, LOWESS, LOESS, spline, Gaussian process, Principal components, Partial least squares, Projection pursuit, Quantile regression, k-nearest neighbor, Radius neighbor, IDW, Nadaraya Watson, Priestley Chao, Gasser Muller, RBF Network, RVM, Decision tree, Random forest, Extra trees, GBDT, XGBoost, SVR, MLP, GMR, Isotonic, Ramer Douglas Peucker, Theil-Sen, Passing-Bablok, Repeated median |
| interpolation | Nearest neighbor, IDW, (Spherical) Linear, Brahmagupta, Logarithmic, Cosine, (Inverse) Smoothstep, Cubic, (Centripetal) Catmull-Rom, Hermite, Polynomial, Lagrange, Trigonometric, Spline, RBF Network, Akima, Natural neighbor, Delaunay |
Expand Down
63 changes: 26 additions & 37 deletions js/view/least_square.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import Matrix from '../../lib/util/matrix.js'
import LeastSquares from '../../lib/model/least_square.js'
import stringToFunction from '../expression.js'
import EnsembleBinaryModel from '../../lib/model/ensemble_binary.js'
import Controller from '../controller.js'

const combination_repetition = (n, k) => {
const c = []
Expand Down Expand Up @@ -56,6 +57,7 @@ export class BasisFunctions {
}

makeHtml(r) {
r = d3.select(r)
if (!this._e) {
this._e = r.append('div').attr('id', `ls_model_${this._name}`)
} else {
Expand Down Expand Up @@ -160,47 +162,12 @@ export class BasisFunctions {
}
}

var dispLeastSquares = function (elm, platform) {
export default function (platform) {
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
platform.setting.ml.reference = {
title: 'Least squares (Wikipedia)',
url: 'https://en.wikipedia.org/wiki/Least_squares',
}
const fitModel = () => {
let model
if (platform.task === 'CF') {
const method = elm.select('[name=method]').property('value')
model = new EnsembleBinaryModel(LeastSquares, method)
} else {
model = new LeastSquares()
}
model.fit(basisFunctions.apply(platform.trainInput).toArray(), platform.trainOutput)

let pred = model.predict(basisFunctions.apply(platform.testInput(2)).toArray())
platform.testResult(pred)
}

if (platform.task === 'CF') {
elm.append('select')
.attr('name', 'method')
.selectAll('option')
.data(['oneone', 'onerest'])
.enter()
.append('option')
.property('value', d => d)
.text(d => d)
}
const basisFunctions = new BasisFunctions(platform)
basisFunctions.makeHtml(elm)

elm.append('input')
.attr('type', 'button')
.attr('value', 'Fit')
.on('click', () => fitModel())
}

export default function (platform) {
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
dispLeastSquares(platform.setting.ml.configElement, platform)
platform.setting.ml.detail = `
The model form is
$$
Expand All @@ -218,4 +185,26 @@ $$
$$
where $ G_{ij} = g_i(x_j) $.
`
const controller = new Controller(platform)
const fitModel = () => {
let model
if (platform.task === 'CF') {
model = new EnsembleBinaryModel(LeastSquares, method.value)
} else {
model = new LeastSquares()
}
model.fit(basisFunctions.apply(platform.trainInput).toArray(), platform.trainOutput)

let pred = model.predict(basisFunctions.apply(platform.testInput(2)).toArray())
platform.testResult(pred)
}

let method = null
if (platform.task === 'CF') {
method = controller.select(['oneone', 'onerest'])
}
const basisFunctions = new BasisFunctions(platform)
basisFunctions.makeHtml(controller.element)

controller.input.button('Fit').on('click', () => fitModel())
}
109 changes: 46 additions & 63 deletions js/view/ridge.js
Original file line number Diff line number Diff line change
@@ -1,31 +1,52 @@
import Matrix from '../../lib/util/matrix.js'

import { BasisFunctions } from './least_square.js'
import { Ridge, KernelRidge } from '../../lib/model/ridge.js'
import { Ridge, MulticlassRidge, KernelRidge } from '../../lib/model/ridge.js'
import EnsembleBinaryModel from '../../lib/model/ensemble_binary.js'
import Controller from '../controller.js'

var dispRidge = function (elm, platform) {
export default function (platform) {
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
platform.setting.ml.reference = {
title: 'Ridge regression (Wikipedia)',
url: 'https://en.wikipedia.org/wiki/Ridge_regression',
}
platform.setting.ml.detail = `
The model form is
$$
f(X) = X W + \\epsilon
$$
The loss function can be written as
$$
L(W) = \\| X W - y \\|^2 + \\lambda \\| W \\|^2
$$
where $ y $ is the observed value corresponding to $ X $.
Therefore, the optimum parameter $ \\hat{W} $ is estimated as
$$
\\hat{W} = \\left( X^T X + \\lambda I \\right)^{-1} X^T y
$$
`
const controller = new Controller(platform)
const task = platform.task
const fitModel = cb => {
const fitModel = () => {
const dim = platform.datas.dimension
const kernel = elm.select('[name=kernel]').property('value')
const kernelName = kernel === 'no kernel' ? null : kernel
const kernelName = kernel.value === 'no kernel' ? null : kernel.value
let model
const l = +elm.select('[name=lambda]').property('value')
const l = +lambda.value
if (task === 'CF') {
const method = elm.select('[name=method]').property('value')
if (kernelName) {
model = new EnsembleBinaryModel(function () {
return new KernelRidge(l, kernelName)
}, method)
}, method.value)
} else {
model = new EnsembleBinaryModel(function () {
return new Ridge(l)
}, method)
if (method.value === 'multiclass') {
model = new MulticlassRidge(l)
} else {
model = new EnsembleBinaryModel(function () {
return new Ridge(l)
}, method.value)
}
}
} else {
if (kernelName) {
Expand Down Expand Up @@ -53,61 +74,23 @@ var dispRidge = function (elm, platform) {
}

const basisFunction = new BasisFunctions(platform)
let method = null
if (task === 'CF') {
elm.append('select')
.attr('name', 'method')
.selectAll('option')
.data(['oneone', 'onerest'])
.enter()
.append('option')
.property('value', d => d)
.text(d => d)
method = controller.select(['oneone', 'onerest', 'multiclass']).on('change', () => {
if (method.value === 'multiclass') {
kernel.element.style.display = 'none'
} else {
kernel.element.style.display = null
}
})
}
let kernel = null
if (task !== 'FS') {
basisFunction.makeHtml(elm)
elm.append('select')
.attr('name', 'kernel')
.selectAll('option')
.data(['no kernel', 'gaussian'])
.enter()
.append('option')
.property('value', d => d)
.text(d => d)
basisFunction.makeHtml(controller.element)
kernel = controller.select(['no kernel', 'gaussian'])
} else {
elm.append('input').attr('type', 'hidden').attr('name', 'kernel').property('value', '')
kernel = controller.input({ type: 'hidden', value: '' })
}
elm.append('span').text('lambda = ')
elm.append('select')
.attr('name', 'lambda')
.selectAll('option')
.data([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100])
.enter()
.append('option')
.property('value', d => d)
.text(d => d)
elm.append('input')
.attr('type', 'button')
.attr('value', 'Fit')
.on('click', () => fitModel())
}

export default function (platform) {
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
dispRidge(platform.setting.ml.configElement, platform)
platform.setting.ml.detail = `
The model form is
$$
f(X) = X W + \\epsilon
$$
The loss function can be written as
$$
L(W) = \\| X W - y \\|^2 + \\lambda \\| W \\|^2
$$
where $ y $ is the observed value corresponding to $ X $.
Therefore, the optimum parameter $ \\hat{W} $ is estimated as
$$
\\hat{W} = \\left( X^T X + \\lambda I \\right)^{-1} X^T y
$$
`
const lambda = controller.select({ label: 'lambda = ', values: [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100] })
controller.input.button('Fit').on('click', () => fitModel())
}
67 changes: 67 additions & 0 deletions lib/model/ridge.js
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,73 @@ export class Ridge {
}
}

/**
 * Multiclass ridge regression
 */
export class MulticlassRidge {
	/**
	 * @param {number} [lambda=0.1] Regularization strength
	 */
	constructor(lambda = 0.1) {
		this._w = null
		this._lambda = lambda
		this._classes = []
	}

	/**
	 * Category list
	 *
	 * @type {*[]}
	 */
	get categories() {
		return this._classes
	}

	/**
	 * Fit model.
	 *
	 * @param {Array<Array<number>>} x Training data
	 * @param {*[]} y Target values
	 */
	fit(x, y) {
		x = Matrix.fromArray(x)
		this._classes = [...new Set(y)]
		// Map each class label to its column index once, so building the
		// target matrix is O(n) instead of O(n * #classes) via indexOf.
		const classIndex = new Map(this._classes.map((c, i) => [c, i]))
		// One-hot-style target matrix: +1 in the sample's class column, -1 elsewhere.
		const p = new Matrix(y.length, this._classes.length, -1)
		for (let i = 0; i < y.length; i++) {
			p.set(i, classIndex.get(y[i]), 1)
		}
		// Ridge normal equations: W = (X^T X + lambda I)^-1 X^T P
		const xtx = x.tDot(x)
		for (let i = 0; i < xtx.rows; i++) {
			xtx.addAt(i, i, this._lambda)
		}

		this._w = xtx.solve(x.t).dot(p)
	}

	/**
	 * Returns predicted values.
	 *
	 * @param {Array<Array<number>>} x Sample data
	 * @returns {*[]} Predicted values
	 */
	predict(x) {
		x = Matrix.fromArray(x)
		// For each row, pick the class whose linear score is largest.
		return x
			.dot(this._w)
			.argmax(1)
			.value.map(i => this._classes[i])
	}

	/**
	 * Returns importances of the features.
	 *
	 * @returns {Array<Array<number>>} Importances (one row per feature, one column per class)
	 */
	importance() {
		return this._w.toArray()
	}
}

/**
* Kernel ridge regression
*/
Expand Down
2 changes: 1 addition & 1 deletion tests/gui/view/least_square.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ describe('classification', () => {
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
const buttons = await methodMenu.waitForSelector('.buttons')

const methods = await buttons.waitForSelector('[name=method]')
const methods = await buttons.waitForSelector('select:nth-of-type(1)')
await expect((await methods.getProperty('value')).jsonValue()).resolves.toBe('oneone')
const preset = await buttons.waitForSelector('[name=preset]')
await expect((await preset.getProperty('value')).jsonValue()).resolves.toBe('linear')
Expand Down
6 changes: 3 additions & 3 deletions tests/gui/view/ridge.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ describe('classification', () => {
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
const buttons = await methodMenu.waitForSelector('.buttons')

const methods = await buttons.waitForSelector('[name=method]')
const methods = await buttons.waitForSelector('select:nth-of-type(1)')
await expect((await methods.getProperty('value')).jsonValue()).resolves.toBe('oneone')
const preset = await buttons.waitForSelector('[name=preset]')
await expect((await preset.getProperty('value')).jsonValue()).resolves.toBe('linear')
const kernel = await buttons.waitForSelector('[name=kernel]')
const kernel = await buttons.waitForSelector('select:nth-of-type(2)')
await expect((await kernel.getProperty('value')).jsonValue()).resolves.toBe('no kernel')
const lambda = await buttons.waitForSelector('[name=lambda]')
const lambda = await buttons.waitForSelector('select:nth-of-type(3)')
await expect((await lambda.getProperty('value')).jsonValue()).resolves.toBe('0')
}, 10000)

Expand Down
38 changes: 37 additions & 1 deletion tests/lib/model/ridge.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ import { jest } from '@jest/globals'
jest.retryTimes(3)

import Matrix from '../../../lib/util/matrix.js'
import { Ridge, KernelRidge } from '../../../lib/model/ridge.js'
import { Ridge, MulticlassRidge, KernelRidge } from '../../../lib/model/ridge.js'

import { rmse } from '../../../lib/evaluate/regression.js'
import { accuracy } from '../../../lib/evaluate/classification.js'

describe('ridge', () => {
test('default', () => {
Expand Down Expand Up @@ -40,6 +41,41 @@ describe('ridge', () => {
})
})

describe('multiclass ridge', () => {
	test('default', () => {
		expect(new MulticlassRidge()._lambda).toBe(0.1)
	})

	test('fit', () => {
		const model = new MulticlassRidge(0.001)
		// Two well-separated gaussian clusters of 50 samples each.
		const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, [0, 5], 0.2)).toArray()
		// Labels: 'a' for the first cluster, 'b' for the second.
		const t = x.map((_, i) => String.fromCharCode('a'.charCodeAt(0) + Math.floor(i / 50)))
		model.fit(x, t)
		const acc = accuracy(model.predict(x), t)
		expect(acc).toBeGreaterThan(0.75)
	})

	test('importance', () => {
		const model = new MulticlassRidge(0.01)
		const x = Matrix.concat(Matrix.randn(50, 3, 0, 0.2), Matrix.randn(50, 3, 5, 0.2)).toArray()
		const t = x.map((_, i) => String.fromCharCode('a'.charCodeAt(0) + Math.floor(i / 50)))
		model.fit(x, t)
		// Expect one row per feature (3) and one column per class (2).
		const imp = model.importance()
		expect(imp).toHaveLength(3)
		for (const row of imp) {
			expect(row).toHaveLength(2)
		}
	})
})

describe('kernel ridge', () => {
test('default', () => {
const model = new KernelRidge()
Expand Down

0 comments on commit d5fac14

Please sign in to comment.