Skip to content

Commit

Permalink
Add LOESS (#681)
Browse files Browse the repository at this point in the history
* Add LOESS

* Add LOWESS test and bit improve

* Fix class name
  • Loading branch information
ishii-norimi authored Nov 8, 2023
1 parent db31de4 commit 3195069
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 11 deletions.
1 change: 1 addition & 0 deletions js/model_selector.js
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ const AIMethods = [
{ value: 'poisson', title: 'Poisson' },
{ value: 'segmented', title: 'Segmented' },
{ value: 'lowess', title: 'LOWESS' },
{ value: 'loess', title: 'LOESS' },
{ value: 'spline', title: 'Spline' },
{ value: 'gaussian_process', title: 'Gaussian Process' },
{ value: 'pcr', title: 'Principal Components' },
Expand Down
14 changes: 14 additions & 0 deletions js/view/loess.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import LOESS from '../../lib/model/loess.js'
import Controller from '../controller.js'

export default function (platform) {
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
const controller = new Controller(platform)
const fitModel = () => {
const model = new LOESS()
model.fit(platform.trainInput, platform.trainOutput)
platform.testResult(model.predict(platform.testInput(10)))
}

controller.input.button('Fit').on('click', () => fitModel())
}
17 changes: 6 additions & 11 deletions js/view/lowess.js
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import LOWESS from '../../lib/model/lowess.js'
import Controller from '../controller.js'

var dispLOWESS = function (elm, platform) {
const fitModel = cb => {
export default function (platform) {
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
const controller = new Controller(platform)
const fitModel = () => {
const model = new LOWESS()
model.fit(platform.trainInput, platform.trainOutput)
platform.testResult(model.predict(platform.testInput(10)))
}

elm.append('input')
.attr('type', 'button')
.attr('value', 'Fit')
.on('click', () => fitModel())
}

export default function (platform) {
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
dispLOWESS(platform.setting.ml.configElement, platform)
controller.input.button('Fit').on('click', () => fitModel())
}
62 changes: 62 additions & 0 deletions lib/model/loess.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import Matrix from '../util/matrix.js'

/**
* Locally estimated scatterplot smoothing
*/
export default class LOESS {
// https://en.wikipedia.org/wiki/Local_regression
// https://github.com/arokem/lowess
// https://jp.mathworks.com/help/curvefit/smoothing-data_ja_JP.html
constructor() {
this._k = (a, b) => {
const d = Matrix.sub(a, b)
d.map(v => v * v)
const s = d.sum(1)
s.map(v => (v <= 1 ? (1 - Math.sqrt(v) ** 3) ** 3 : 0))
return s
}
}

/**
* Fit model.
*
* @param {Array<Array<number>>} x Training data
* @param {Array<Array<number>>} y Target values
*/
fit(x, y) {
this._x = Matrix.fromArray(x)
this._b = Matrix.resize(this._x, this._x.rows, this._x.cols * 2 + 1, 1)
this._b.set(
0,
this._x.cols,
Matrix.map(this._x, v => v ** 2)
)
this._y = Matrix.fromArray(y)
}

/**
* Returns predicted values.
*
* @param {Array<Array<number>>} x Sample data
* @returns {Array<Array<number>>} Predicted values
*/
predict(x) {
x = Matrix.fromArray(x)
const pred = []
for (let i = 0; i < x.rows; i++) {
const xi = x.row(i)
const w = this._k(this._x, xi)
const bw = Matrix.mult(this._b, w)

const p = bw.tDot(this._b).solve(bw.tDot(this._y))
const rx = Matrix.resize(xi, xi.rows, xi.cols * 2 + 1, 1)
rx.set(
0,
xi.cols,
Matrix.map(xi, v => v ** 2)
)
pred.push(rx.dot(p).value)
}
return pred
}
}
42 changes: 42 additions & 0 deletions tests/gui/view/loess.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { getPage } from '../helper/browser'

describe('regression', () => {
/** @type {Awaited<ReturnType<getPage>>} */
let page
beforeEach(async () => {
page = await getPage()
}, 10000)

afterEach(async () => {
await page?.close()
})

test('initialize', async () => {
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select')
await taskSelectBox.selectOption('RG')
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp')
await modelSelectBox.selectOption('loess')
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
const buttons = await methodMenu.waitForSelector('.buttons')

const fit = await buttons.waitForSelector('input:nth-of-type(1)')
await expect((await fit.getProperty('value')).jsonValue()).resolves.toBe('Fit')
}, 10000)

test('learn', async () => {
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select')
await taskSelectBox.selectOption('RG')
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp')
await modelSelectBox.selectOption('loess')
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
const buttons = await methodMenu.waitForSelector('.buttons')

const methodFooter = await page.waitForSelector('#method_footer', { state: 'attached' })
await expect(methodFooter.evaluate(el => el.textContent)).resolves.toBe('')

const initButton = await buttons.waitForSelector('input[value=Fit]')
await initButton.evaluate(el => el.click())

await expect(methodFooter.evaluate(el => el.textContent)).resolves.toMatch(/^RMSE:[0-9.]+$/)
}, 10000)
})
42 changes: 42 additions & 0 deletions tests/gui/view/lowess.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { getPage } from '../helper/browser'

describe('regression', () => {
/** @type {Awaited<ReturnType<getPage>>} */
let page
beforeEach(async () => {
page = await getPage()
}, 10000)

afterEach(async () => {
await page?.close()
})

test('initialize', async () => {
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select')
await taskSelectBox.selectOption('RG')
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp')
await modelSelectBox.selectOption('lowess')
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
const buttons = await methodMenu.waitForSelector('.buttons')

const fit = await buttons.waitForSelector('input:nth-of-type(1)')
await expect((await fit.getProperty('value')).jsonValue()).resolves.toBe('Fit')
}, 10000)

test('learn', async () => {
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select')
await taskSelectBox.selectOption('RG')
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp')
await modelSelectBox.selectOption('lowess')
const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
const buttons = await methodMenu.waitForSelector('.buttons')

const methodFooter = await page.waitForSelector('#method_footer', { state: 'attached' })
await expect(methodFooter.evaluate(el => el.textContent)).resolves.toBe('')

const initButton = await buttons.waitForSelector('input[value=Fit]')
await initButton.evaluate(el => el.click())

await expect(methodFooter.evaluate(el => el.textContent)).resolves.toMatch(/^RMSE:[0-9.]+$/)
}, 10000)
})
20 changes: 20 additions & 0 deletions tests/lib/model/loess.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { jest } from '@jest/globals'
jest.retryTimes(3)

import Matrix from '../../../lib/util/matrix.js'
import LOESS from '../../../lib/model/loess.js'

import { rmse } from '../../../lib/evaluate/regression.js'

test('fit', () => {
const model = new LOESS()
const x = Matrix.random(50, 2, -2, 2).toArray()
const t = []
for (let i = 0; i < x.length; i++) {
t[i] = [x[i][0] + x[i][1] + (Math.random() - 0.5) / 20]
}
model.fit(x, t)
const y = model.predict(x)
const err = rmse(y, t)[0]
expect(err).toBeLessThan(0.5)
})

0 comments on commit 3195069

Please sign in to comment.