From 976cba0b28da7974cfbad29f6f5702e8f68fba03 Mon Sep 17 00:00:00 2001 From: ishii-norimi Date: Mon, 20 Nov 2023 20:11:52 +0900 Subject: [PATCH] Add Multiclass Ridge classifier --- README.md | 2 +- js/view/least_square.js | 1 + js/view/ridge.js | 109 ++++++++++++++-------------------- lib/model/ridge.js | 67 +++++++++++++++++++++ tests/lib/model/ridge.test.js | 38 +++++++++++- 5 files changed, 152 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index 0c3f6953..cbac3587 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ for (let i = 0; i < n; i++) { | task | model | | ---- | ----- | | clustering | (Soft / Kernel / Genetic / Weighted) k-means, k-means++, k-medois, k-medians, x-means, G-means, LBG, ISODATA, Fuzzy c-means, Possibilistic c-means, Agglomerative (complete linkage, single linkage, group average, Ward's, centroid, weighted average, median), DIANA, Monothetic, Mutual kNN, Mean shift, DBSCAN, OPTICS, HDBSCAN, DENCLUE, DBCLASD, CLUES, PAM, CLARA, CLARANS, BIRCH, CURE, ROCK, C2P, PLSA, Latent dirichlet allocation, GMM, VBGMM, Affinity propagation, Spectral clustering, Mountain, (Growing) SOM, GTM, (Growing) Neural gas, Growing cell structures, LVQ, ART, SVC, CAST, CHAMELEON, COLL, CLIQUE, PROCLUS, ORCLUS, FINDIT, NMF, Autoencoder | -| classification | (Fisher's) Linear discriminant, Quadratic discriminant, Mixture discriminant, Least squares, Ridge, (Complement / Negation / Universal-set / Selective) Naive Bayes (gaussian), AODE, (Fuzzy / Weighted) k-nearest neighbor, Radius neighbor, Nearest centroid, ENN, ENaN, NNBCA, ADAMENN, DANN, IKNN, Decision tree, Random forest, Extra trees, GBDT, XGBoost, ALMA, (Aggressive) ROMMA, (Bounded) Online gradient descent, (Budgeted online) Passive aggressive, RLS, (Selective-sampling) Second order perceptron, AROW, NAROW, Confidence weighted, CELLIP, IELLIP, Normal herd, Stoptron, (Kernelized) Pegasos, MIRA, Forgetron, Projectron, Projectron++, Banditron, Ballseptron, (Multiclass) BSGD, ILK, SILK, (Multinomial) Logistic regression, (Multinomial) Probit, SVM, Gaussian process, HMM, CRF, Bayesian Network, LVQ, (Average / Multiclass / Voted / Kernelized / Selective-sampling / Margin / Shifting / Budget / Tighter / Tightest) Perceptron, PAUM, RBP, ADALINE, MADALINE, MLP, LMNN | +| classification | (Fisher's) Linear discriminant, Quadratic discriminant, Mixture discriminant, Least squares, (Multiclass / Kernel) Ridge, (Complement / Negation / Universal-set / Selective) Naive Bayes (gaussian), AODE, (Fuzzy / Weighted) k-nearest neighbor, Radius neighbor, Nearest centroid, ENN, ENaN, NNBCA, ADAMENN, DANN, IKNN, Decision tree, Random forest, Extra trees, GBDT, XGBoost, ALMA, (Aggressive) ROMMA, (Bounded) Online gradient descent, (Budgeted online) Passive aggressive, RLS, (Selective-sampling) Second order perceptron, AROW, NAROW, Confidence weighted, CELLIP, IELLIP, Normal herd, Stoptron, (Kernelized) Pegasos, MIRA, Forgetron, Projectron, Projectron++, Banditron, Ballseptron, (Multiclass) BSGD, ILK, SILK, (Multinomial) Logistic regression, (Multinomial) Probit, SVM, Gaussian process, HMM, CRF, Bayesian Network, LVQ, (Average / Multiclass / Voted / Kernelized / Selective-sampling / Margin / Shifting / Budget / Tighter / Tightest) Perceptron, PAUM, RBP, ADALINE, MADALINE, MLP, LMNN | | semi-supervised classification | k-nearest neighbor, Radius neighbor, Label propagation, Label spreading, k-means, GMM, S3VM, Ladder network | | regression | Least squares, Ridge, Lasso, Elastic net, RLS, Bayesian linear, Poisson, Least absolute deviations, Huber, Tukey, Least trimmed squares, Least median squares, Lp norm linear, SMA, Deming, Segmented, LOWESS, LOESS, spline, Gaussian process, Principal components, Partial least squares, Projection pursuit, Quantile regression, k-nearest neighbor, Radius neighbor, IDW, Nadaraya Watson, Priestley Chao, Gasser Muller, RBF Network, RVM, Decision tree, Random forest, Extra trees, GBDT, XGBoost, SVR, MLP, GMR, Isotonic, Ramer Douglas Peucker, Theil-Sen, Passing-Bablok, Repeated median | | interpolation | Nearest neighbor, IDW, (Spherical) Linear, Brahmagupta, Logarithmic, Cosine, (Inverse) Smoothstep, Cubic, (Centripetal) Catmull-Rom, Hermit, Polynomial, Lagrange, Trigonometric, Spline, RBF Network, Akima, Natural neighbor, Delaunay | diff --git a/js/view/least_square.js b/js/view/least_square.js index 9564d183..ff95299b 100644 --- a/js/view/least_square.js +++ b/js/view/least_square.js @@ -56,6 +56,7 @@ export class BasisFunctions { } makeHtml(r) { + r = d3.select(r) if (!this._e) { this._e = r.append('div').attr('id', `ls_model_${this._name}`) } else { diff --git a/js/view/ridge.js b/js/view/ridge.js index 78cdbbd3..32322bca 100644 --- a/js/view/ridge.js +++ b/js/view/ridge.js @@ -1,31 +1,52 @@ import Matrix from '../../lib/util/matrix.js' import { BasisFunctions } from './least_square.js' -import { Ridge, KernelRidge } from '../../lib/model/ridge.js' +import { Ridge, MulticlassRidge, KernelRidge } from '../../lib/model/ridge.js' import EnsembleBinaryModel from '../../lib/model/ensemble_binary.js' +import Controller from '../controller.js' -var dispRidge = function (elm, platform) { +export default function (platform) { + platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.' platform.setting.ml.reference = { title: 'Ridge regression (Wikipedia)', url: 'https://en.wikipedia.org/wiki/Ridge_regression', } + platform.setting.ml.detail = ` +The model form is +$$ +f(X) = X W + \\epsilon +$$ + +The loss function can be written as +$$ +L(W) = \\| X W - y \\|^2 + \\lambda \\| W \\|^2 +$$ +where $ y $ is the observed value corresponding to $ X $. +Therefore, the optimum parameter $ \\hat{W} $ is estimated as +$$ +\\hat{W} = \\left( X^T X + \\lambda I \\right)^{-1} X^T y +$$ +` + const controller = new Controller(platform) const task = platform.task - const fitModel = cb => { + const fitModel = () => { const dim = platform.datas.dimension - const kernel = elm.select('[name=kernel]').property('value') - const kernelName = kernel === 'no kernel' ? null : kernel + const kernelName = kernel.value === 'no kernel' ? null : kernel.value let model - const l = +elm.select('[name=lambda]').property('value') + const l = +lambda.value if (task === 'CF') { - const method = elm.select('[name=method]').property('value') if (kernelName) { model = new EnsembleBinaryModel(function () { return new KernelRidge(l, kernelName) - }, method) + }, method.value) } else { - model = new EnsembleBinaryModel(function () { - return new Ridge(l) - }, method) + if (method.value === 'multiclass') { + model = new MulticlassRidge(l) + } else { + model = new EnsembleBinaryModel(function () { + return new Ridge(l) + }, method.value) + } } } else { if (kernelName) { @@ -53,61 +74,23 @@ var dispRidge = function (elm, platform) { } const basisFunction = new BasisFunctions(platform) + let method = null if (task === 'CF') { - elm.append('select') - .attr('name', 'method') - .selectAll('option') - .data(['oneone', 'onerest']) - .enter() - .append('option') - .property('value', d => d) - .text(d => d) + method = controller.select(['oneone', 'onerest', 'multiclass']).on('change', () => { + if (method.value === 'multiclass') { + kernel.element.style.display = 'none' + } else { + kernel.element.style.display = null + } + }) } + let kernel = null if (task !== 'FS') { - basisFunction.makeHtml(elm) - elm.append('select') - .attr('name', 'kernel') - .selectAll('option') - .data(['no kernel', 'gaussian']) - .enter() - .append('option') - .property('value', d => d) - .text(d => d) + basisFunction.makeHtml(controller.element) + kernel = controller.select(['no kernel', 'gaussian']) } else { - elm.append('input').attr('type', 'hidden').attr('name', 'kernel').property('value', '') + kernel = controller.input({ type: 'hidden', value: '' }) } - elm.append('span').text('lambda = ') - elm.append('select') - .attr('name', 'lambda') - .selectAll('option') - .data([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100]) - .enter() - .append('option') - .property('value', d => d) - .text(d => d) - elm.append('input') - .attr('type', 'button') - .attr('value', 'Fit') - .on('click', () => fitModel()) -} - -export default function (platform) { - platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.' - dispRidge(platform.setting.ml.configElement, platform) - platform.setting.ml.detail = ` -The model form is -$$ -f(X) = X W + \\epsilon -$$ - -The loss function can be written as -$$ -L(W) = \\| X W - y \\|^2 + \\lambda \\| W \\|^2 -$$ -where $ y $ is the observed value corresponding to $ X $. -Therefore, the optimum parameter $ \\hat{W} $ is estimated as -$$ -\\hat{W} = \\left( X^T X + \\lambda I \\right)^{-1} X^T y -$$ -` + const lambda = controller.select({ label: 'lambda = ', values: [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100] }) + controller.input.button('Fit').on('click', () => fitModel()) } diff --git a/lib/model/ridge.js b/lib/model/ridge.js index 88318c78..f9ac6892 100644 --- a/lib/model/ridge.js +++ b/lib/model/ridge.js @@ -50,6 +50,73 @@ export class Ridge { } } +/** + * Multiclass ridge regressioin + */ +export class MulticlassRidge { + /** + * @param {number} [lambda=0.1] Regularization strength + */ + constructor(lambda = 0.1) { + this._w = null + this._lambda = lambda + this._classes = [] + } + + /** + * Category list + * + * @type {*[]} + */ + get categories() { + return this._classes + } + + /** + * Fit model. + * + * @param {Array>} x Training data + * @param {*[]} y Target values + */ + fit(x, y) { + x = Matrix.fromArray(x) + this._classes = [...new Set(y)] + const p = new Matrix(y.length, this._classes.length, -1) + for (let i = 0; i < y.length; i++) { + p.set(i, this._classes.indexOf(y[i]), 1) + } + const xtx = x.tDot(x) + for (let i = 0; i < xtx.rows; i++) { + xtx.addAt(i, i, this._lambda) + } + + this._w = xtx.solve(x.t).dot(p) + } + + /** + * Returns predicted values. + * + * @param {Array>} x Sample data + * @returns {*[]} Predicted values + */ + predict(x) { + x = Matrix.fromArray(x) + return x + .dot(this._w) + .argmax(1) + .value.map(i => this._classes[i]) + } + + /** + * Returns importances of the features. + * + * @returns {number[]} Importances + */ + importance() { + return this._w.toArray() + } +} + /** * Kernel ridge regression */ diff --git a/tests/lib/model/ridge.test.js b/tests/lib/model/ridge.test.js index 8b33da4c..40c4c5b4 100644 --- a/tests/lib/model/ridge.test.js +++ b/tests/lib/model/ridge.test.js @@ -2,9 +2,10 @@ import { jest } from '@jest/globals' jest.retryTimes(3) import Matrix from '../../../lib/util/matrix.js' -import { Ridge, KernelRidge } from '../../../lib/model/ridge.js' +import { Ridge, MulticlassRidge, KernelRidge } from '../../../lib/model/ridge.js' import { rmse } from '../../../lib/evaluate/regression.js' +import { accuracy } from '../../../lib/evaluate/classification.js' describe('ridge', () => { test('default', () => { @@ -40,6 +41,41 @@ describe('ridge', () => { }) }) +describe('multiclass ridge', () => { + test('default', () => { + const model = new MulticlassRidge() + expect(model._lambda).toBe(0.1) + }) + + test('fit', () => { + const model = new MulticlassRidge(0.001) + const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, [0, 5], 0.2)).toArray() + const t = [] + for (let i = 0; i < x.length; i++) { + t[i] = String.fromCharCode('a'.charCodeAt(0) + Math.floor(i / 50)) + } + model.fit(x, t) + const y = model.predict(x) + const acc = accuracy(y, t) + expect(acc).toBeGreaterThan(0.75) + }) + + test('importance', () => { + const model = new MulticlassRidge(0.01) + const x = Matrix.concat(Matrix.randn(50, 3, 0, 0.2), Matrix.randn(50, 3, 5, 0.2)).toArray() + const t = [] + for (let i = 0; i < x.length; i++) { + t[i] = String.fromCharCode('a'.charCodeAt(0) + Math.floor(i / 50)) + } + model.fit(x, t) + const importance = model.importance() + expect(importance).toHaveLength(3) + expect(importance[0]).toHaveLength(2) + expect(importance[1]).toHaveLength(2) + expect(importance[2]).toHaveLength(2) + }) +}) + describe('kernel ridge', () => { test('default', () => { const model = new KernelRidge()