-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add LOESS * Add LOWESS test and bit improve * Fix class name
- Loading branch information
1 parent
db31de4
commit 3195069
Showing
7 changed files
with
187 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import LOESS from '../../lib/model/loess.js' | ||
import Controller from '../controller.js' | ||
|
||
export default function (platform) { | ||
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.' | ||
const controller = new Controller(platform) | ||
const fitModel = () => { | ||
const model = new LOESS() | ||
model.fit(platform.trainInput, platform.trainOutput) | ||
platform.testResult(model.predict(platform.testInput(10))) | ||
} | ||
|
||
controller.input.button('Fit').on('click', () => fitModel()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,14 @@ | ||
import LOWESS from '../../lib/model/lowess.js' | ||
import Controller from '../controller.js' | ||
|
||
var dispLOWESS = function (elm, platform) { | ||
const fitModel = cb => { | ||
export default function (platform) { | ||
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.' | ||
const controller = new Controller(platform) | ||
const fitModel = () => { | ||
const model = new LOWESS() | ||
model.fit(platform.trainInput, platform.trainOutput) | ||
platform.testResult(model.predict(platform.testInput(10))) | ||
} | ||
|
||
elm.append('input') | ||
.attr('type', 'button') | ||
.attr('value', 'Fit') | ||
.on('click', () => fitModel()) | ||
} | ||
|
||
export default function (platform) { | ||
platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.' | ||
dispLOWESS(platform.setting.ml.configElement, platform) | ||
controller.input.button('Fit').on('click', () => fitModel()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import Matrix from '../util/matrix.js' | ||
|
||
/** | ||
* Locally estimated scatterplot smoothing | ||
*/ | ||
export default class LOESS { | ||
// https://en.wikipedia.org/wiki/Local_regression | ||
// https://github.com/arokem/lowess | ||
// https://jp.mathworks.com/help/curvefit/smoothing-data_ja_JP.html | ||
constructor() { | ||
this._k = (a, b) => { | ||
const d = Matrix.sub(a, b) | ||
d.map(v => v * v) | ||
const s = d.sum(1) | ||
s.map(v => (v <= 1 ? (1 - Math.sqrt(v) ** 3) ** 3 : 0)) | ||
return s | ||
} | ||
} | ||
|
||
/** | ||
* Fit model. | ||
* | ||
* @param {Array<Array<number>>} x Training data | ||
* @param {Array<Array<number>>} y Target values | ||
*/ | ||
fit(x, y) { | ||
this._x = Matrix.fromArray(x) | ||
this._b = Matrix.resize(this._x, this._x.rows, this._x.cols * 2 + 1, 1) | ||
this._b.set( | ||
0, | ||
this._x.cols, | ||
Matrix.map(this._x, v => v ** 2) | ||
) | ||
this._y = Matrix.fromArray(y) | ||
} | ||
|
||
/** | ||
* Returns predicted values. | ||
* | ||
* @param {Array<Array<number>>} x Sample data | ||
* @returns {Array<Array<number>>} Predicted values | ||
*/ | ||
predict(x) { | ||
x = Matrix.fromArray(x) | ||
const pred = [] | ||
for (let i = 0; i < x.rows; i++) { | ||
const xi = x.row(i) | ||
const w = this._k(this._x, xi) | ||
const bw = Matrix.mult(this._b, w) | ||
|
||
const p = bw.tDot(this._b).solve(bw.tDot(this._y)) | ||
const rx = Matrix.resize(xi, xi.rows, xi.cols * 2 + 1, 1) | ||
rx.set( | ||
0, | ||
xi.cols, | ||
Matrix.map(xi, v => v ** 2) | ||
) | ||
pred.push(rx.dot(p).value) | ||
} | ||
return pred | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import { getPage } from '../helper/browser' | ||
|
||
describe('regression', () => { | ||
/** @type {Awaited<ReturnType<getPage>>} */ | ||
let page | ||
beforeEach(async () => { | ||
page = await getPage() | ||
}, 10000) | ||
|
||
afterEach(async () => { | ||
await page?.close() | ||
}) | ||
|
||
test('initialize', async () => { | ||
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select') | ||
await taskSelectBox.selectOption('RG') | ||
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp') | ||
await modelSelectBox.selectOption('loess') | ||
const methodMenu = await page.waitForSelector('#ml_selector #method_menu') | ||
const buttons = await methodMenu.waitForSelector('.buttons') | ||
|
||
const fit = await buttons.waitForSelector('input:nth-of-type(1)') | ||
await expect((await fit.getProperty('value')).jsonValue()).resolves.toBe('Fit') | ||
}, 10000) | ||
|
||
test('learn', async () => { | ||
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select') | ||
await taskSelectBox.selectOption('RG') | ||
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp') | ||
await modelSelectBox.selectOption('loess') | ||
const methodMenu = await page.waitForSelector('#ml_selector #method_menu') | ||
const buttons = await methodMenu.waitForSelector('.buttons') | ||
|
||
const methodFooter = await page.waitForSelector('#method_footer', { state: 'attached' }) | ||
await expect(methodFooter.evaluate(el => el.textContent)).resolves.toBe('') | ||
|
||
const initButton = await buttons.waitForSelector('input[value=Fit]') | ||
await initButton.evaluate(el => el.click()) | ||
|
||
await expect(methodFooter.evaluate(el => el.textContent)).resolves.toMatch(/^RMSE:[0-9.]+$/) | ||
}, 10000) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import { getPage } from '../helper/browser' | ||
|
||
describe('regression', () => { | ||
/** @type {Awaited<ReturnType<getPage>>} */ | ||
let page | ||
beforeEach(async () => { | ||
page = await getPage() | ||
}, 10000) | ||
|
||
afterEach(async () => { | ||
await page?.close() | ||
}) | ||
|
||
test('initialize', async () => { | ||
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select') | ||
await taskSelectBox.selectOption('RG') | ||
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp') | ||
await modelSelectBox.selectOption('lowess') | ||
const methodMenu = await page.waitForSelector('#ml_selector #method_menu') | ||
const buttons = await methodMenu.waitForSelector('.buttons') | ||
|
||
const fit = await buttons.waitForSelector('input:nth-of-type(1)') | ||
await expect((await fit.getProperty('value')).jsonValue()).resolves.toBe('Fit') | ||
}, 10000) | ||
|
||
test('learn', async () => { | ||
const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select') | ||
await taskSelectBox.selectOption('RG') | ||
const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp') | ||
await modelSelectBox.selectOption('lowess') | ||
const methodMenu = await page.waitForSelector('#ml_selector #method_menu') | ||
const buttons = await methodMenu.waitForSelector('.buttons') | ||
|
||
const methodFooter = await page.waitForSelector('#method_footer', { state: 'attached' }) | ||
await expect(methodFooter.evaluate(el => el.textContent)).resolves.toBe('') | ||
|
||
const initButton = await buttons.waitForSelector('input[value=Fit]') | ||
await initButton.evaluate(el => el.click()) | ||
|
||
await expect(methodFooter.evaluate(el => el.textContent)).resolves.toMatch(/^RMSE:[0-9.]+$/) | ||
}, 10000) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import { jest } from '@jest/globals' | ||
jest.retryTimes(3) | ||
|
||
import Matrix from '../../../lib/util/matrix.js' | ||
import LOESS from '../../../lib/model/loess.js' | ||
|
||
import { rmse } from '../../../lib/evaluate/regression.js' | ||
|
||
test('fit', () => { | ||
const model = new LOESS() | ||
const x = Matrix.random(50, 2, -2, 2).toArray() | ||
const t = [] | ||
for (let i = 0; i < x.length; i++) { | ||
t[i] = [x[i][0] + x[i][1] + (Math.random() - 0.5) / 20] | ||
} | ||
model.fit(x, t) | ||
const y = model.predict(x) | ||
const err = rmse(y, t)[0] | ||
expect(err).toBeLessThan(0.5) | ||
}) |