diff --git a/source/tf.js b/source/tf.js index 53a599ca18..2d3be60054 100644 --- a/source/tf.js +++ b/source/tf.js @@ -1618,74 +1618,13 @@ tf.TensorBundle.Table.Block = class { } }; -tf.BinaryReader = class { +tf.BinaryReader = class extends base.BinaryReader { constructor(buffer) { - this._buffer = buffer; - this._position = 0; - this._length = this._buffer.length; - this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength); + super(buffer); this._decoder = new TextDecoder('utf-8'); } - get position() { - return this._position; - } - - get length() { - return this._length; - } - - seek(position) { - this._position = position >= 0 ? position : this._length + position; - if (this._position > this._length) { - throw new tf.Error(`Expected ${this._position - this._length} more bytes. The file might be corrupted. Unexpected end of file.`); - } - } - - skip(offset) { - this._position += offset; - if (this._position > this._length) { - throw new tf.Error(`Expected ${this._position - this._length} more bytes. The file might be corrupted. Unexpected end of file.`); - } - } - - read(size) { - const position = this._position; - this.skip(size); - return this._buffer.subarray(position, this._position); - } - - byte() { - const position = this._position; - this.skip(1); - return this._dataView.getUint8(position); - } - - uint16() { - const position = this._position; - this.skip(2); - return this._dataView.getUint16(position, true); - } - - int32() { - const position = this._position; - this.skip(4); - return this._dataView.getInt32(position, true); - } - - uint32() { - const position = this._position; - this.skip(4); - return this._dataView.getUint32(position, true); - } - - uint64() { - const position = this._position; - this.skip(4); - return this._dataView.getUint64(position, true); - } - string() { const size = this.uint32(); const buffer = this.read(size); @@ -2149,7 +2088,8 @@ tf.Context = class { } break; } - case 'int64': { + case 'int64': + case 'SymInt': { if (input.constant !== undefined && Number.isInteger(parseInt(input.constant))) { continue; } @@ -2209,7 +2149,8 @@ tf.Context = class { input.metadata = arg; break; } - case 'int64': { + case 'int64': + case 'SymInt': { const value = parseInt(input.constant); input.attr = new tf.proto.tensorflow.AttrValue(); input.attr.i = value;