diff --git a/js/data/dashboard_estat.js b/js/data/dashboard_estat.js index 4702d014..6cc35382 100644 --- a/js/data/dashboard_estat.js +++ b/js/data/dashboard_estat.js @@ -62,9 +62,6 @@ export default class EStatData extends FixData { constructor(manager) { super(manager) this._name = 'Nikkei Indexes' - this._shift = [] - this._scale = [] - this._scaled = true this._lastRequested = 0 const elm = this.setting.data.configElement @@ -118,16 +115,6 @@ export default class EStatData extends FixData { this._loader = document.createElement('div') elm.appendChild(this._loader) - const optionalElm = document.createElement('div') - const scaledCheckbox = document.createElement('input') - scaledCheckbox.type = 'checkbox' - scaledCheckbox.checked = true - scaledCheckbox.onchange = () => { - this._scaled = scaledCheckbox.checked - this._readyData() - } - optionalElm.append('Scale', scaledCheckbox) - elm.appendChild(optionalElm) this._initIndicatorSelector().then(() => { const info = presetInfos[this._name] if (info) { @@ -160,15 +147,7 @@ export default class EStatData extends FixData { } get x() { - if (!this._scaled) return this.originalX - if (this._requireDateInput) { - return this._datetime.map(v => [v]) - } - this._readyScaledData() - return this._x.map(v => { - const c = v.map((a, d) => (a - this._shift[d]) / this._scale[d]) - return this._selector.object.map(i => c[i]) - }) + return this.originalX } get originalY() { @@ -180,13 +159,7 @@ export default class EStatData extends FixData { } get y() { - if (!this._scaled) return this.originalY - this._readyScaledData() - const target = this._selector.target - if (target >= 0) { - return this._x.map(v => (v[target] - this._shift[target]) / this._scale[target]) - } - return Array(this._x.length).fill(0) + return this.originalY } get params() { @@ -404,33 +377,6 @@ export default class EStatData extends FixData { } } - _readyScaledData() { - if (this._scale.length > 0) { - return - } - this._shift = [] - this._scale = [] - if (this._x.length > 0) { - const min = Array(this._x[0].length).fill(Infinity) - const max = Array(this._x[0].length).fill(-Infinity) - for (let i = 0; i < this._x.length; i++) { - for (let d = 0; d < this._x[i].length; d++) { - min[d] = Math.min(min[d], this._x[i][d]) - max[d] = Math.max(max[d], this._x[i][d]) - } - } - const rmax = 10 - const rmin = 0 - for (let d = 0; d < min.length; d++) { - if (min[d] === max[d]) { - max[d] = min[d] + 1 - } - this._scale[d] = (max[d] - min[d]) / (rmax - rmin) - this._shift[d] = min[d] - rmin * this._scale[d] - } - } - } - async _getMeta(func) { const keySuffix = func.toUpperCase() === 'INDICATOR' || func.toUpperCase() === 'REGION' ? 'INF' : 'INFO' const metaKey = `GET_META_${func.toUpperCase()}_${keySuffix}` @@ -530,8 +476,6 @@ export default class EStatData extends FixData { async _readyData() { this._x = [] - this._shift = [] - this._scale = [] this._index = null this._datetime = null this._manager.platform?.init() diff --git a/js/data/eurostat.js b/js/data/eurostat.js index a7c79ace..df6d2d80 100644 --- a/js/data/eurostat.js +++ b/js/data/eurostat.js @@ -30,9 +30,6 @@ export default class EurostatData extends FixData { constructor(manager) { super(manager) this._name = 'Population and employment' - this._shift = [] - this._scale = [] - this._scaled = true this._filterItems = null this._lastRequested = 0 @@ -88,17 +85,6 @@ export default class EurostatData extends FixData { this._loader = document.createElement('div') elm.appendChild(this._loader) - const optionalElm = document.createElement('div') - const scaledCheckbox = document.createElement('input') - scaledCheckbox.type = 'checkbox' - scaledCheckbox.checked = true - scaledCheckbox.onchange = () => { - this._scaled = scaledCheckbox.checked - this._readyData() - } - optionalElm.append('Scale', scaledCheckbox) - elm.appendChild(optionalElm) - this._readyData() } @@ -125,15 +111,7 @@ export default class EurostatData extends FixData { } get x() { - if (!this._scaled) return this.originalX - if (this._requireDateInput) { - return this._datetime.map(v => [v]) - } - this._readyScaledData() - return this._x.map(v => { - const c = v.map((a, d) => (a - this._shift[d]) / this._scale[d]) - return this._selector.object.map(i => c[i]) - }) + return this.originalX } get originalY() { @@ -144,13 +122,7 @@ export default class EurostatData extends FixData { } get y() { - if (!this._scaled) return this.originalY - this._readyScaledData() - const target = this._selector.target - if (target >= 0) { - return this._x.map(v => (v[target] - this._shift[target]) / this._scale[target]) - } - return Array(this._x.length).fill(0) + return this.originalY } get params() { @@ -166,33 +138,6 @@ export default class EurostatData extends FixData { } } - _readyScaledData() { - if (this._scale.length > 0) { - return - } - this._shift = [] - this._scale = [] - if (this._x.length > 0) { - const min = Array(this._x[0].length).fill(Infinity) - const max = Array(this._x[0].length).fill(-Infinity) - for (let i = 0; i < this._x.length; i++) { - for (let d = 0; d < this._x[i].length; d++) { - min[d] = Math.min(min[d], this._x[i][d]) - max[d] = Math.max(max[d], this._x[i][d]) - } - } - const rmax = 10 - const rmin = 0 - for (let d = 0; d < min.length; d++) { - if (min[d] === max[d]) { - max[d] = min[d] + 1 - } - this._scale[d] = (max[d] - min[d]) / (rmax - rmin) - this._shift[d] = min[d] - rmin * this._scale[d] - } - } - } - async _getData(datasetCode, query) { const params = { format: 'JSON', @@ -239,8 +184,6 @@ export default class EurostatData extends FixData { async _readyData() { this._x = [] - this._shift = [] - this._scale = [] this._index = null this._datetime = null this._manager.platform?.init() diff --git a/js/model_selector.js b/js/model_selector.js index 3027c3a4..1f6c4585 100644 --- a/js/model_selector.js +++ b/js/model_selector.js @@ -62,7 +62,11 @@ const AITask = { const AIPreprocess = { function: { title: 'Basis function', - tasks: ['CF', 'RG', 'RL'], + tasks: ['CF', 'SC', 'RG', 'IN', 'RL', 'AD', 'DR', 'CP'], + }, + transform: { + title: 'Transformers', + tasks: ['CT', 'CF', 'SC', 'RG', 'IN', 'RL', 'AD', 'DR', 'FS', 'SM', 'TP', 'CP'], }, } for (const ap of Object.keys(AIPreprocess)) { diff --git a/js/platform/base.js b/js/platform/base.js index 953c7d6c..62ee7b77 100644 --- a/js/platform/base.js +++ b/js/platform/base.js @@ -149,30 +149,6 @@ export class DefaultPlatform extends BasePlatform { this._renderer[0].testResult(pred) } - evaluate(cb) { - if (this.task !== 'CF' && this.task !== 'RG' && this.task !== 'RL') { - return - } - cb(this.datas.x, p => { - const t = this.datas.y - if (this.task === 'CF' || this.task === 'RL') { - let acc = 0 - for (let i = 0; i < t.length; i++) { - if (t[i] === p[i]) { - acc++ - } - } - this._getEvaluateElm().innerText = 'Accuracy:' + acc / t.length - } else if (this.task === 'RG') { - let rmse = 0 - for (let i = 0; i < t.length; i++) { - rmse += (t[i] - p[i]) ** 2 - } - this._getEvaluateElm().innerText = 'RMSE:' + Math.sqrt(rmse / t.length) - } - }) - } - init() { this._cur_dimension = this.setting.dimension this.setting.footer.innerText = '' @@ -189,10 +165,24 @@ export class DefaultPlatform extends BasePlatform { } } + invertScale(x) { + for (const preprocess of this._manager.preprocesses) { + if (preprocess.inverse) { + if (Array.isArray(x[0])) { + x = preprocess.inverse(x) + } else { + x = preprocess.inverse([x])[0] + } + } + } + return x + } + centroids(center, cls, { line = false, duration = 0 } = {}) { if (!this._centroids) { this._centroids = new CentroidPlotter(this._renderer[0]) } + center = this.invertScale(center) this._centroids.set(center, cls, { line, duration }) } diff --git a/js/platform/series.js b/js/platform/series.js index e2736756..e7dfe347 100644 --- a/js/platform/series.js +++ b/js/platform/series.js @@ -61,6 +61,19 @@ export default class SeriesPlatform extends BasePlatform { } } + invertScale(x) { + for (const preprocess of this._manager.preprocesses) { + if (preprocess.inverse) { + if (Array.isArray(x[0])) { + x = preprocess.inverse(x) + } else { + x = preprocess.inverse([x])[0] + } + } + } + return x + } + resetPredicts() { this._renderer.forEach(rend => rend.resetPredicts()) } diff --git a/js/preprocess/transform.js b/js/preprocess/transform.js new file mode 100644 index 00000000..07450c6c --- /dev/null +++ b/js/preprocess/transform.js @@ -0,0 +1,63 @@ +import MinmaxNormalization from '../../lib/model/minmax.js' +import Standardization from '../../lib/model/standardization.js' +import MaxAbsScaler from '../../lib/model/maxabs.js' +import RobustScaler from '../../lib/model/robust_scaler.js' +import BoxCox from '../../lib/model/box_cox.js' +import YeoJohnson from '../../lib/model/yeo_johnson.js' + +const transformers = { + minmax: MinmaxNormalization, + standard: Standardization, + maxabs: MaxAbsScaler, + robust: RobustScaler, + 'Box-Cox': BoxCox, + 'Yeo-Johnson': YeoJohnson, +} + +export default class TransformPreprocessor { + constructor(manager) { + this._manager = manager + this._method = 'standard' + + this.init() + } + + init() { + if (!this._r) { + const elm = this._manager.setting.preprocess.configElement + this._r = document.createElement('div') + elm.append(this._r) + } else { + this._r.replaceChildren() + } + const methodElm = document.createElement('div') + const method = document.createElement('select') + for (const key of Object.keys(transformers)) { + const opt = method.appendChild(document.createElement('option')) + opt.value = opt.innerText = key + } + method.onchange = () => { + this._method = method.value + this._manager.setting.ml.refresh() + } + method.value = this._method + methodElm.append('Method ', method) + this._r.append(methodElm) + } + + apply(x, { dofit = true }) { + if (dofit) { + this._model = new transformers[this._method]() + this._model.fit(x) + } + return this._model.predict(x) + } + + inverse(z) { + return this._model.inverse(z) + } + + terminate() { + this._r?.remove() + } +} diff --git a/js/renderer/line.js b/js/renderer/line.js index b63ed662..f77e55a1 100644 --- a/js/renderer/line.js +++ b/js/renderer/line.js @@ -422,10 +422,15 @@ export default class LineRenderer extends BaseRenderer { const datas = this.datas const path = [] if (datas.length > 0) { - path.push(this.toPoint([datas.length - 1, datas.x[datas.length - 1] || [datas.y[datas.length - 1]]])) + path.push( + this.toPoint([ + datas.length - 1, + datas.dimension > 0 ? datas.x[datas.length - 1] : [datas.y[datas.length - 1]], + ]) + ) } for (let i = 0; i < pred.length; i++) { - const a = this.toPoint([i + datas.length, pred[i]]) + const a = this.toPoint([i + datas.length, this._manager.platform.invertScale(pred[i])]) const p = new DataPoint(this._r_tile, a, specialCategory.dummy) path.push(a) this._pred_points.push(p) @@ -439,7 +444,7 @@ export default class LineRenderer extends BaseRenderer { } else if (task === 'SM') { const path = [] for (let i = 0; i < pred.length; i++) { - const a = this.toPoint([i, pred[i]]) + const a = this.toPoint([i, this._manager.platform.invertScale(pred[i])]) path.push(a) } if (path.length > 0) { diff --git a/js/view/dbscan.js b/js/view/dbscan.js index de52153c..90991a97 100644 --- a/js/view/dbscan.js +++ b/js/view/dbscan.js @@ -24,34 +24,39 @@ export default function (platform) { const scale = platform._renderer[0].scale[0] const datas = platform.trainInput - if (metric.value === 'euclid') { - for (let i = 0; i < datas.length; i++) { - const circle = document.createElementNS('http://www.w3.org/2000/svg', 'circle') - circle.setAttribute('cx', datas[i][0] * scale) - circle.setAttribute('cy', datas[i][1] * scale) - circle.setAttribute('r', eps.value * scale) - circle.setAttribute('fill-opacity', 0) - circle.setAttribute('stroke', getCategoryColor(pred[i] + 1)) - range.append(circle) - } - } else if (metric.value === 'manhattan') { - for (let i = 0; i < datas.length; i++) { + const invscale = platform.invertScale([ + Array(platform.datas.dimension).fill(1), + Array(platform.datas.dimension).fill(2), + ]) + const s0 = invscale[1][platform._renderer[0]._select[0]] - invscale[0][platform._renderer[0]._select[0]] + const s1 = invscale[1][platform._renderer[0]._select[1]] - invscale[0][platform._renderer[0]._select[1]] + for (let i = 0; i < datas.length; i++) { + const p = platform._renderer[0].toPoint(platform.invertScale(datas[i])) + if (metric.value === 'euclid') { + const ellipse = document.createElementNS('http://www.w3.org/2000/svg', 'ellipse') + ellipse.setAttribute('cx', p[0]) + ellipse.setAttribute('cy', p[1]) + ellipse.setAttribute('rx', eps.value * scale * s0) + ellipse.setAttribute('ry', eps.value * scale * s1) + ellipse.setAttribute('fill-opacity', 0) + ellipse.setAttribute('stroke', getCategoryColor(pred[i] + 1)) + range.append(ellipse) + } else if (metric.value === 'manhattan') { const polygon = document.createElementNS('http://www.w3.org/2000/svg', 'polygon') - const x0 = datas[i][0] * scale - const y0 = datas[i][1] * scale const d = eps.value * scale - polygon.setAttribute('points', `${x0 - d},${y0} ${x0},${y0 - d} ${x0 + d},${y0} ${x0},${y0 + d}`) + polygon.setAttribute( + 'points', + `${p[0] - d * s0},${p[1]} ${p[0]},${p[1] - d * s1} ${p[0] + d * s0},${p[1]} ${p[0]},${p[1] + d * s1}` + ) polygon.setAttribute('fill-opacity', 0) polygon.setAttribute('stroke', getCategoryColor(pred[i] + 1)) range.append(polygon) - } - } else if (metric.value === 'chebyshev') { - for (let i = 0; i < datas.length; i++) { + } else if (metric.value === 'chebyshev') { const rect = document.createElementNS('http://www.w3.org/2000/svg', 'rect') - rect.setAttribute('x', (datas[i][0] - eps.value) * scale) - rect.setAttribute('y', (datas[i][1] - eps.value) * scale) - rect.setAttribute('width', eps.value * 2 * scale) - rect.setAttribute('height', eps.value * 2 * scale) + rect.setAttribute('x', p[0] - eps.value * scale * s0) + rect.setAttribute('y', p[1] - eps.value * scale * s1) + rect.setAttribute('width', eps.value * 2 * scale * s0) + rect.setAttribute('height', eps.value * 2 * scale * s1) rect.setAttribute('fill-opacity', 0) rect.setAttribute('stroke', getCategoryColor(pred[i] + 1)) range.append(rect) diff --git a/js/view/decision_tree.js b/js/view/decision_tree.js index 6186dc20..cc84ce42 100644 --- a/js/view/decision_tree.js +++ b/js/view/decision_tree.js @@ -2,93 +2,6 @@ import Matrix from '../../lib/util/matrix.js' import { DecisionTreeClassifier, DecisionTreeRegression } from '../../lib/model/decision_tree.js' import Controller from '../controller.js' -import { getCategoryColor } from '../utils.js' - -class DecisionTreePlotter { - constructor(platform) { - this._platform = platform - this._mode = platform.task - this._svg = platform.svg - this._r = null - this._lineEdge = [] - } - - remove() { - this._svg.querySelector('.separation')?.remove() - } - - plot(tree) { - this._svg.querySelector('.separation')?.remove() - if (this._platform.datas.length === 0) { - return - } - this._r = document.createElementNS('http://www.w3.org/2000/svg', 'g') - this._r.classList.add('separation') - if (this._platform.datas.dimension === 1) { - this._svg.append(this._r) - } else { - this._svg.insertBefore(this._r, this._svg.firstChild) - this._r.setAttribute('opacity', 0.5) - } - this._lineEdge = [] - this._dispRange(tree._tree) - if (this._platform.datas.dimension === 1) { - const line = p => { - let s = '' - for (let i = 0; i < p.length; i++) { - s += `${i === 0 ? 'M' : 'L'}${p[i][0]},${p[i][1]}` - } - return s - } - const path = document.createElementNS('http://www.w3.org/2000/svg', 'path') - path.setAttribute('stroke', 'red') - path.setAttribute('fill-opacity', 0) - path.setAttribute('d', line(this._lineEdge)) - this._r.append(path) - } - } - - _dispRange(root, r) { - r = r || this._platform.datas.domain - if (root.children.length === 0) { - let max_cls = 0, - max_v = 0 - if (this._mode === 'CF') { - root.value.forEach((v, k) => { - if (v > max_v) { - max_v = v - max_cls = k - } - }) - } else { - max_cls = root.value - } - if (this._platform.datas.dimension === 1) { - const p1 = this._platform._renderer[0].toPoint([r[0][0], max_cls]) - const p2 = this._platform._renderer[0].toPoint([r[0][1], max_cls]) - this._lineEdge.push(p1) - this._lineEdge.push(p2) - } else { - const p1 = this._platform._renderer[0].toPoint([r[0][0], r[1][0]]) - const p2 = this._platform._renderer[0].toPoint([r[0][1], r[1][1]]) - const rect = document.createElementNS('http://www.w3.org/2000/svg', 'rect') - rect.setAttribute('x', p1[0]) - rect.setAttribute('y', p1[1]) - rect.setAttribute('width', p2[0] - p1[0]) - rect.setAttribute('height', p2[1] - p1[1]) - rect.setAttribute('fill', getCategoryColor(max_cls)) - this._r.append(rect) - } - } else { - root.children.forEach((n, i) => { - let r0 = [[].concat(r[0]), [].concat(r[1])] - let mm = i === 0 ? 1 : 0 - r0[root.feature][mm] = root.threshold - this._dispRange(n, r0) - }) - } - } -} export default function (platform) { platform.setting.ml.usage = 'Click and add data point. Next, click "Initialize". Finally, click "Separate".' @@ -98,7 +11,6 @@ export default function (platform) { } const controller = new Controller(platform) const mode = platform.task - const plotter = new DecisionTreePlotter(platform) let tree = null const dispRange = function () { @@ -110,14 +22,12 @@ export default function (platform) { const x = Matrix.fromArray(platform.trainInput) platform.trainResult = x.col(idx).toArray() } else if (platform.datas.dimension <= 2) { - plotter.plot(tree) + let pred = tree.predict(platform.testInput(platform.datas.dimension === 1 ? 0.1 : 1)) + platform.testResult(pred) } else { let pred = tree.predict(platform.testInput(2)) platform.testResult(pred) } - platform.evaluate((x, e_cb) => { - e_cb(tree.predict(x)) - }) } const methods = mode === 'CF' ? ['CART', 'ID3'] : ['CART'] @@ -148,8 +58,4 @@ export default function (platform) { }) const depth = controller.text('0') controller.text(' depth ') - - return () => { - plotter.remove() - } } diff --git a/js/view/gmm.js b/js/view/gmm.js index 922cb72b..d67d3a01 100644 --- a/js/view/gmm.js +++ b/js/view/gmm.js @@ -6,7 +6,8 @@ import { specialCategory, getCategoryColor } from '../utils.js' class GMMPlotter { // see http://d.hatena.ne.jp/natsutan/20110421/1303344155 - constructor(svg, model, grayscale = false) { + constructor(platform, svg, model, grayscale = false) { + this._platform = platform this._r = document.createElementNS('http://www.w3.org/2000/svg', 'g') svg.append(this._r) this._model = model @@ -14,6 +15,7 @@ class GMMPlotter { this._circle = [] this._grayscale = grayscale this._duration = 200 + this._scale = platform._renderer[0].scale?.[0] ?? 0 } terminate() { @@ -24,19 +26,31 @@ class GMMPlotter { if (!this._model._m[i]) { return } - const cn = this._model._m[i].value + const cn = this._platform.invertScale(this._model._m[i].value) const s = this._model._s[i].value const su2 = (s[0] + s[3] + Math.sqrt((s[0] - s[3]) ** 2 + 4 * s[1] ** 2)) / 2 const sv2 = (s[0] + s[3] - Math.sqrt((s[0] - s[3]) ** 2 + 4 * s[1] ** 2)) / 2 const c = 2.146 - let t = (360 * Math.atan((su2 - s[0]) / s[1])) / (2 * Math.PI) - if (isNaN(t)) { - t = 0 + let rad = Math.atan((su2 - s[0]) / s[1]) + if (isNaN(rad)) { + rad = 0 } + const invscale = this._platform.invertScale([ + Array(this._platform.datas.dimension).fill(1), + Array(this._platform.datas.dimension).fill(2), + ]) - ell.setAttribute('rx', c * Math.sqrt(su2) * 1000) - ell.setAttribute('ry', c * Math.sqrt(sv2) * 1000) - ell.setAttribute('transform', 'translate(' + cn[0] * 1000 + ',' + cn[1] * 1000 + ') ' + 'rotate(' + t + ')') + ell.setAttribute('rx', c * Math.sqrt(su2) * this._scale) + ell.setAttribute('ry', c * Math.sqrt(sv2) * this._scale) + ell.setAttribute('vector-effect', 'non-scaling-stroke') + const s0 = + invscale[1][this._platform._renderer[0]._select[0]] - invscale[0][this._platform._renderer[0]._select[0]] + const s1 = + invscale[1][this._platform._renderer[0]._select[1]] - invscale[0][this._platform._renderer[0]._select[1]] + ell.setAttribute( + 'transform', + `matrix(${Math.cos(rad) * s0} ${Math.sin(rad) * s1} ${-Math.sin(rad) * s0} ${Math.cos(rad) * s1} ${cn[0] * this._scale} ${cn[1] * this._scale})` + ) } add(category) { @@ -84,7 +98,7 @@ export default function (platform) { } else if (mode === 'RG') { model = new GMR() } - const plotter = new GMMPlotter(svg, model, grayscale) + const plotter = new GMMPlotter(platform, svg, model, grayscale) const fitModel = (doFit, cb) => { if (mode === 'AD') { if (doFit) model.fit(platform.trainInput) diff --git a/js/view/mean_shift.js b/js/view/mean_shift.js index 6425474b..832da488 100644 --- a/js/view/mean_shift.js +++ b/js/view/mean_shift.js @@ -21,13 +21,13 @@ export default function (platform) { let model = null const plot = () => { - const scale = platform._renderer[0].scale?.[0] ?? 0 const pred = model.predict(threshold.value) platform.trainResult = pred.map(v => v + 1) for (let i = 0; i < c.length; i++) { + const centroid = platform._renderer[0].toPoint(platform.invertScale(model._centroids[i])) c[i].setAttribute('stroke', getCategoryColor(pred[i] + 1)) - c[i].setAttribute('cx', model._centroids[i][0] * scale) - c[i].setAttribute('cy', model._centroids[i][1] * scale) + c[i].setAttribute('cx', centroid[0]) + c[i].setAttribute('cy', centroid[1]) } } @@ -44,17 +44,22 @@ export default function (platform) { } model.init(tx) if (platform.task !== 'SG' && scale > 0) { + const invscale = platform.invertScale([ + Array(platform.datas.dimension).fill(1), + Array(platform.datas.dimension).fill(2), + ]) + const s0 = invscale[1][platform._renderer[0]._select[0]] - invscale[0][platform._renderer[0]._select[0]] + const s1 = invscale[1][platform._renderer[0]._select[1]] - invscale[0][platform._renderer[0]._select[1]] c.forEach(c => c.remove()) - c = platform._renderer[0].points.map(p => { - const circle = document.createElementNS('http://www.w3.org/2000/svg', 'circle') - circle.setAttribute('cx', p.at[0] * scale) - circle.setAttribute('cy', p.at[1] * scale) - circle.setAttribute('r', model.h * scale) - circle.setAttribute('stroke', 'black') - circle.setAttribute('fill-opacity', 0) - circle.setAttribute('stroke-opacity', 0.5) - csvg.append(circle) - return circle + c = platform._renderer[0].points.map(() => { + const ellipse = document.createElementNS('http://www.w3.org/2000/svg', 'ellipse') + ellipse.setAttribute('rx', model.h * scale * s0) + ellipse.setAttribute('ry', model.h * scale * s1) + ellipse.setAttribute('stroke', 'black') + ellipse.setAttribute('fill-opacity', 0) + ellipse.setAttribute('stroke-opacity', 0.5) + csvg.append(ellipse) + return ellipse }) } plot() diff --git a/js/view/vbgmm.js b/js/view/vbgmm.js index e59c759f..56c45a2a 100644 --- a/js/view/vbgmm.js +++ b/js/view/vbgmm.js @@ -3,7 +3,8 @@ import Controller from '../controller.js' import { getCategoryColor } from '../utils.js' class VBGMMPlotter { - constructor(svg, model) { + constructor(platform, svg, model) { + this._platform = platform this._r = document.createElementNS('http://www.w3.org/2000/svg', 'g') svg.append(this._r) this._model = model @@ -11,7 +12,7 @@ class VBGMMPlotter { this._circle = [] this._rm = [] this._duration = 200 - this._scale = 1000 + this._scale = platform._renderer[0].scale?.[0] ?? 0 for (let i = 0; i < this._size; i++) { this.add(i + 1) @@ -37,21 +38,30 @@ class VBGMMPlotter { } _set_el_attr(ell, i) { - let cn = this._model.means.row(i).value + let cn = this._platform.invertScale(this._model.means.row(i).value) let s = this._model.covs[i].value const su2 = (s[0] + s[3] + Math.sqrt((s[0] - s[3]) ** 2 + 4 * s[1] ** 2)) / 2 const sv2 = (s[0] + s[3] - Math.sqrt((s[0] - s[3]) ** 2 + 4 * s[1] ** 2)) / 2 const c = 2.146 - let t = (360 * Math.atan((su2 - s[0]) / s[1])) / (2 * Math.PI) - if (isNaN(t)) { - t = 0 + let rad = Math.atan((su2 - s[0]) / s[1]) + if (isNaN(rad)) { + rad = 0 } + const invscale = this._platform.invertScale([ + Array(this._platform.datas.dimension).fill(1), + Array(this._platform.datas.dimension).fill(2), + ]) ell.setAttribute('rx', c * Math.sqrt(su2) * this._scale) ell.setAttribute('ry', c * Math.sqrt(sv2) * this._scale) + ell.setAttribute('vector-effect', 'non-scaling-stroke') + const s0 = + invscale[1][this._platform._renderer[0]._select[0]] - invscale[0][this._platform._renderer[0]._select[0]] + const s1 = + invscale[1][this._platform._renderer[0]._select[1]] - invscale[0][this._platform._renderer[0]._select[1]] ell.setAttribute( 'transform', - 'translate(' + cn[0] * this._scale + ',' + cn[1] * this._scale + ') ' + 'rotate(' + t + ')' + `matrix(${Math.cos(rad) * s0} ${Math.sin(rad) * s1} ${-Math.sin(rad) * s0} ${Math.cos(rad) * s1} ${cn[0] * this._scale} ${cn[1] * this._scale})` ) } @@ -87,7 +97,7 @@ export default function (platform) { platform.trainResult = pred.map(v => v + 1) clusters.value = model.effectivity.reduce((s, v) => s + (v ? 1 : 0), 0) if (!plotter) { - plotter = new VBGMMPlotter(platform.svg, model) + plotter = new VBGMMPlotter(platform, platform.svg, model) } plotter.move() const effectivity = model.effectivity