Skip to content

Commit

Permalink
Accept object for add method of ComputationalGraph class (#902)
Browse files Browse the repository at this point in the history
  • Loading branch information
ishii-norimi authored Dec 1, 2024
1 parent 908e892 commit 200c470
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
7 changes: 2 additions & 5 deletions lib/model/neuralnetwork.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import Matrix from '../util/matrix.js'
import Tensor from '../util/tensor.js'
import Layer from './nns/layer/base.js'
export { default as Layer } from './nns/layer/base.js'

import ComputationalGraph from './nns/graph.js'
Expand Down Expand Up @@ -61,13 +60,11 @@ export default class NeuralNetwork {
}
const graph = new ComputationalGraph()
for (const cn of const_numbers) {
const cl = Layer.fromObject({ type: 'const', value: [[cn]] })
graph.add(cl, `__const_number_${cn}`, [])
graph.add({ type: 'const', value: [[cn]] }, `__const_number_${cn}`, [])
}

for (const l of layers) {
const cl = Layer.fromObject(l)
graph.add(cl, l.name, l.input)
graph.add(l, l.name, l.input)
}

return new NeuralNetwork(graph, optimizer)
Expand Down
11 changes: 9 additions & 2 deletions lib/model/nns/graph.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ import { InputLayer, OutputLayer } from './layer/index.js'
import ONNXImporter from './onnx/onnx_importer.js'

/**
* @typedef {import("./layer/index").PlainLayerObject & {input?: string | string[], name?: string}} LayerObject
* @ignore
* @typedef {import("./layer/index").PlainLayerObject} PlainLayerObject
*/
/**
* @typedef {PlainLayerObject & {input?: string | string[], name?: string}} LayerObject
* @typedef {object} Node
* @property {Layer} layer Layer
* @property {string} name Name of the node
Expand Down Expand Up @@ -127,11 +131,14 @@ export default class ComputationalGraph {

/**
* Add a layer.
* @param {Layer} layer Added layer
* @param {Layer | PlainLayerObject} layer Added layer
* @param {string} [name] Node name
* @param {string[] | string} [inputs] Input node names for the added layer
*/
add(layer, name, inputs = undefined) {
if (!(layer instanceof Layer)) {
layer = Layer.fromObject(layer)
}
let parentinfos = []
if (!inputs) {
if (this._nodes.length > 0) {
Expand Down
10 changes: 10 additions & 0 deletions tests/lib/model/nns/graph.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ describe('Computational Graph', () => {
expect(graph.nodes[1].parents[0].subscript).toBeNull()
})

test('object', () => {
const graph = new ComputationalGraph()
graph.add({ type: 'input' })
graph.add({ type: 'tanh' })

expect(graph.nodes[1].parents).toHaveLength(1)
expect(graph.nodes[1].parents[0].index).toBe(0)
expect(graph.nodes[1].parents[0].subscript).toBeNull()
})

test('string input', () => {
const graph = new ComputationalGraph()
graph.add(Layer.fromObject({ type: 'input' }), 'in')
Expand Down

0 comments on commit 200c470

Please sign in to comment.