-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e4d44fb
commit 14cf44d
Showing
20 changed files
with
1,251 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,16 @@ | ||
export { SGDOptimizer } from './optimizer/sgd.js' | ||
export { MomentumOptimizer } from './optimizer/momentum.js' | ||
export { AdaGradOptimizer } from './optimizer/adagrad.js' | ||
export { RMSPropOptimizer } from './optimizer/rmsprop.js' | ||
export { AdaDeltaOptimizer } from './optimizer/adadelta.js' | ||
export { AdamOptimizer } from './optimizer/adam.js' | ||
export { RMSPropGravesOptimizer } from './optimizer/rmspropgraves.js' | ||
export { SMORMS3Optimizer } from './optimizer/smorms3.js' | ||
export { AdaMaxOptimizer } from './optimizer/adamax.js' | ||
export { NadamOptimizer } from './optimizer/nadam.js' | ||
export { SGDOptimizer as sgd } from './optimizer/sgd.js' | ||
export { MomentumOptimizer as momentum } from './optimizer/momentum.js' | ||
export { AdaGradOptimizer as adagrad } from './optimizer/adagrad.js' | ||
export { RMSPropOptimizer as rmsprop } from './optimizer/rmsprop.js' | ||
export { AdaDeltaOptimizer as adadelta } from './optimizer/adadelta.js' | ||
export { AdamOptimizer as adam } from './optimizer/adam.js' | ||
export { RMSPropGravesOptimizer as rmspropgraves } from './optimizer/rmspropgraves.js' | ||
export { SMORMS3Optimizer as smorms3 } from './optimizer/smorms3.js' | ||
export { AdaMaxOptimizer as adamax } from './optimizer/adamax.js' | ||
export { NadamOptimizer as nadam } from './optimizer/nadam.js' | ||
export { SantaEOptimizer as santae } from './optimizer/santae.js' | ||
export { SantaSSSOptimizer as santasss } from './optimizer/santasss.js' | ||
export { AMSGradOptimizer as amsgrad } from './optimizer/amsgrad.js' | ||
export { AdaBoundOptimizer as adabound } from './optimizer/adabound.js' | ||
export { AMSBoundOptimizer as amsbound } from './optimizer/amsbound.js' | ||
export { AdaBeliefOptimizer as adabelief } from './optimizer/adabelief.js' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import Matrix from '../../../util/matrix.js' | ||
|
||
export class AdaBeliefOptimizer { | ||
constructor(lr = 0.001, beta1 = 0.9, beta2 = 0.999) { | ||
this._learningrate = lr | ||
this._beta1 = beta1 | ||
this._beta2 = beta2 | ||
} | ||
|
||
set learningRate(value) { | ||
this._learningrate = value | ||
} | ||
|
||
manager() { | ||
const this_ = this | ||
return { | ||
get lr() { | ||
return this_._learningrate | ||
}, | ||
params: {}, | ||
delta(key, value) { | ||
const valueIsNumber = typeof value === 'number' | ||
if (valueIsNumber) { | ||
value = new Matrix(1, 1, value) | ||
} | ||
if (!this.params[key]) { | ||
const z = value.copy() | ||
z.fill(0) | ||
this.params[key] = { m: z.copy(), v: z, t: 1 } | ||
} | ||
this.params[key].m.broadcastOperate(value, (a, b) => a * this_._beta1 + b * (1 - this_._beta1)) | ||
const mo = this.params[key].m.copy() | ||
mo.broadcastOperate(value, (a, b) => b - a) | ||
this.params[key].v.broadcastOperate(mo, (a, b) => a * this_._beta2 + (1 - this_._beta2) * b * b) | ||
const nv = 1 - this_._beta1 ** this.params[key].t | ||
const ns = 1 - this_._beta2 ** this.params[key].t | ||
const ret = this.params[key].m.copy() | ||
ret.broadcastOperate(this.params[key].v, (a, b) => (a / nv) * (this.lr / Math.sqrt(b / ns + 1.0e-12))) | ||
this.params[key].t++ | ||
return valueIsNumber ? ret.toScaler() : ret | ||
}, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import Matrix from '../../../util/matrix.js' | ||
|
||
export class AdaBoundOptimizer { | ||
constructor(lr = 0.001, alpha = 0.003, beta1 = 0.9, beta2 = 0.999) { | ||
this._learningrate = lr | ||
this._alpha = alpha | ||
this._beta1 = beta1 | ||
this._beta2 = beta2 | ||
|
||
this._eta_lbound = t => this._learningrate * (1 - 1 / ((1 - beta2) * t + 1)) | ||
this._eta_ubound = t => this._learningrate * (1 + 1 / ((1 - beta2) * t + 1)) | ||
} | ||
|
||
set learningRate(value) { | ||
this._learningrate = value | ||
} | ||
|
||
manager() { | ||
const this_ = this | ||
return { | ||
get lr() { | ||
return this_._learningrate | ||
}, | ||
params: {}, | ||
delta(key, value) { | ||
const valueIsNumber = typeof value === 'number' | ||
if (valueIsNumber) { | ||
value = new Matrix(1, 1, value) | ||
} | ||
if (!this.params[key]) { | ||
const z = value.copy() | ||
z.fill(0) | ||
this.params[key] = { m: z.copy(), v: z, t: 1 } | ||
} | ||
this.params[key].m.broadcastOperate(value, (a, b) => a * this_._beta1 + b * (1 - this_._beta1)) | ||
this.params[key].v.broadcastOperate(value, (a, b) => a * this_._beta2 + (1 - this_._beta2) * b * b) | ||
const eta_lb = this_._eta_lbound(this.params[key].t) | ||
const eta_ub = this_._eta_ubound(this.params[key].t) | ||
const eta = this.params[key].v.copy() | ||
eta.map(v => Math.min(eta_ub, Math.max(eta_lb, this_._alpha / Math.sqrt(v)))) | ||
const ret = this.params[key].m.copy() | ||
ret.broadcastOperate(eta, (a, b) => (a * b) / Math.sqrt(this.params[key].t)) | ||
this.params[key].t++ | ||
return valueIsNumber ? ret.toScaler() : ret | ||
}, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import Matrix from '../../../util/matrix.js' | ||
|
||
export class AMSBoundOptimizer { | ||
constructor(lr = 0.001, alpha = 0.003, beta1 = 0.9, beta2 = 0.999) { | ||
this._learningrate = lr | ||
this._alpha = alpha | ||
this._beta1 = beta1 | ||
this._beta2 = beta2 | ||
|
||
this._eta_lbound = t => this._learningrate * (1 - 1 / ((1 - beta2) * t + 1)) | ||
this._eta_ubound = t => this._learningrate * (1 + 1 / ((1 - beta2) * t + 1)) | ||
} | ||
|
||
set learningRate(value) { | ||
this._learningrate = value | ||
} | ||
|
||
manager() { | ||
const this_ = this | ||
return { | ||
get lr() { | ||
return this_._learningrate | ||
}, | ||
params: {}, | ||
delta(key, value) { | ||
const valueIsNumber = typeof value === 'number' | ||
if (valueIsNumber) { | ||
value = new Matrix(1, 1, value) | ||
} | ||
if (!this.params[key]) { | ||
const z = value.copy() | ||
z.fill(0) | ||
this.params[key] = { m: z.copy(), v: z.copy(), vh: z, t: 1 } | ||
} | ||
this.params[key].m.broadcastOperate(value, (a, b) => a * this_._beta1 + b * (1 - this_._beta1)) | ||
this.params[key].v.broadcastOperate(value, (a, b) => a * this_._beta2 + (1 - this_._beta2) * b * b) | ||
this.params[key].vh.broadcastOperate(this.params[key].v, (a, b) => Math.max(a, b)) | ||
const eta_lb = this_._eta_lbound(this.params[key].t) | ||
const eta_ub = this_._eta_ubound(this.params[key].t) | ||
const eta = this.params[key].vh.copy() | ||
eta.map(v => Math.min(eta_ub, Math.max(eta_lb, this_._alpha / Math.sqrt(v)))) | ||
const ret = this.params[key].m.copy() | ||
ret.broadcastOperate(eta, (a, b) => (a * b) / Math.sqrt(this.params[key].t)) | ||
this.params[key].t++ | ||
return valueIsNumber ? ret.toScaler() : ret | ||
}, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import Matrix from '../../../util/matrix.js' | ||
|
||
export class AMSGradOptimizer { | ||
constructor(lr = 0.001, beta1 = 0.9, beta2 = 0.999) { | ||
this._learningrate = lr | ||
this._beta1 = beta1 | ||
this._beta2 = beta2 | ||
this._a = t => this._learningrate / Math.sqrt(t) | ||
} | ||
|
||
set learningRate(value) { | ||
this._learningrate = value | ||
} | ||
|
||
manager() { | ||
const this_ = this | ||
return { | ||
get lr() { | ||
return this_._learningrate | ||
}, | ||
params: {}, | ||
delta(key, value) { | ||
const valueIsNumber = typeof value === 'number' | ||
if (valueIsNumber) { | ||
value = new Matrix(1, 1, value) | ||
} | ||
if (!this.params[key]) { | ||
const z = value.copy() | ||
z.fill(0) | ||
this.params[key] = { m: z.copy(), v: z.copy(), vh: z, t: 1 } | ||
} | ||
this.params[key].m.broadcastOperate(value, (a, b) => a * this_._beta1 + b * (1 - this_._beta1)) | ||
this.params[key].v.broadcastOperate(value, (a, b) => a * this_._beta2 + b ** 2 * (1 - this_._beta2)) | ||
this.params[key].vh.broadcastOperate(this.params[key].v, (a, b) => Math.max(a, b)) | ||
const ret = this.params[key].m.copy() | ||
const lr = this_._a(this.params[key].t) | ||
ret.broadcastOperate(this.params[key].vh, (a, b) => (lr * a) / Math.sqrt(b + 1.0e-12)) | ||
this.params[key].t++ | ||
return valueIsNumber ? ret.toScaler() : ret | ||
}, | ||
} | ||
} | ||
} |
Oops, something went wrong.