diff --git a/tfjs-converter/docs/supported_ops.md b/tfjs-converter/docs/supported_ops.md index acf066705d8..96dc049094e 100644 --- a/tfjs-converter/docs/supported_ops.md +++ b/tfjs-converter/docs/supported_ops.md @@ -268,6 +268,13 @@ |Not mapped|ifft| |Not mapped|rfft| +## Operations - Strings + +|Tensorflow Op Name|Tensorflow.js Op Name| +|---|---| +|DecodeBase64|decodeBase64| +|EncodeBase64|encodeBase64| + ## Tensors - Transformations |Tensorflow Op Name|Tensorflow.js Op Name| diff --git a/tfjs-converter/python/tensorflowjs/op_list/string.json b/tfjs-converter/python/tensorflowjs/op_list/string.json new file mode 100644 index 00000000000..f7a2dfca6a6 --- /dev/null +++ b/tfjs-converter/python/tensorflowjs/op_list/string.json @@ -0,0 +1,31 @@ +[ + { + "tfOpName": "DecodeBase64", + "category": "string", + "inputs": [ + { + "start": 0, + "name": "input", + "type": "tensor" + } + ] + }, + { + "tfOpName": "EncodeBase64", + "category": "string", + "inputs": [ + { + "start": 0, + "name": "input", + "type": "tensor" + } + ], + "attrs": [ + { + "tfName": "pad", + "name": "pad", + "type": "bool" + } + ] + } +] \ No newline at end of file diff --git a/tfjs-converter/src/operations/executors/string_executor.ts b/tfjs-converter/src/operations/executors/string_executor.ts new file mode 100644 index 00000000000..2ae03e1f778 --- /dev/null +++ b/tfjs-converter/src/operations/executors/string_executor.ts @@ -0,0 +1,46 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tfc from '@tensorflow/tfjs-core'; + +import {NamedTensorsMap} from '../../data/types'; +import {ExecutionContext} from '../../executor/execution_context'; +import {InternalOpExecutor, Node} from '../types'; + +import {getParamValue} from './utils'; + +export let executeOp: InternalOpExecutor = + (node: Node, tensorMap: NamedTensorsMap, + context: ExecutionContext): tfc.Tensor[] => { + switch (node.op) { + case 'DecodeBase64': { + const input = + getParamValue('str', node, tensorMap, context) as tfc.Tensor; + return [tfc.decodeBase64(input)]; + } + case 'EncodeBase64': { + const input = + getParamValue('str', node, tensorMap, context) as tfc.Tensor; + const pad = getParamValue('pad', node, tensorMap, context) as boolean; + return [tfc.encodeBase64(input, pad)]; + } + default: + throw TypeError(`Node type ${node.op} is not implemented`); + } + }; + +export const CATEGORY = 'string'; diff --git a/tfjs-converter/src/operations/executors/string_executor_test.ts b/tfjs-converter/src/operations/executors/string_executor_test.ts new file mode 100644 index 00000000000..ff0ae48082b --- /dev/null +++ b/tfjs-converter/src/operations/executors/string_executor_test.ts @@ -0,0 +1,62 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import * as tfc from '@tensorflow/tfjs-core'; + +import {ExecutionContext} from '../../executor/execution_context'; +import {Node} from '../types'; + +import {executeOp} from './string_executor'; +import {createBoolAttr, createTensorAttr} from './test_helper'; + +describe('string', () => { + let node: Node; + const input1 = [tfc.tensor(['a'], [1], 'string')]; + const context = new ExecutionContext({}, {}); + + beforeEach(() => { + node = { + name: 'test', + op: '', + category: 'string', + inputNames: ['input1'], + inputs: [], + inputParams: {str: createTensorAttr(0)}, + attrParams: {}, + children: [] + }; + }); + + describe('executeOp', () => { + describe('DecodeBase64', () => { + it('should call tfc.decodeBase64', () => { + spyOn(tfc, 'decodeBase64'); + node.op = 'DecodeBase64'; + executeOp(node, {input1}, context); + expect(tfc.decodeBase64).toHaveBeenCalledWith(input1[0]); + }); + }); + describe('EncodeBase64', () => { + it('should call tfc.encodeBase64', () => { + spyOn(tfc, 'encodeBase64'); + node.op = 'EncodeBase64'; + node.attrParams.pad = createBoolAttr(true); + executeOp(node, {input1}, context); + expect(tfc.encodeBase64).toHaveBeenCalledWith(input1[0], true); + }); + }); + }); +}); diff --git a/tfjs-converter/src/operations/op_list/string.ts b/tfjs-converter/src/operations/op_list/string.ts new file mode 100644 index 00000000000..e891af94dad --- /dev/null +++ b/tfjs-converter/src/operations/op_list/string.ts @@ -0,0 +1,32 @@ +import {OpMapper} from '../types'; + +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +export const json: OpMapper[] = [ + { + 'tfOpName': 'DecodeBase64', + 'category': 'string', + 'inputs': [{'start': 0, 'name': 'input', 'type': 'tensor'}] + }, + { + 'tfOpName': 'EncodeBase64', + 'category': 'string', + 'inputs': [{'start': 0, 'name': 'input', 'type': 'tensor'}], + 'attrs': [{'tfName': 'pad', 'name': 'pad', 'type': 'bool'}] + } +]; diff --git a/tfjs-converter/src/operations/types.ts b/tfjs-converter/src/operations/types.ts index cb5bf5ced69..6ca12f7fe5f 100644 --- a/tfjs-converter/src/operations/types.ts +++ b/tfjs-converter/src/operations/types.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google LLC. All Rights Reserved. + * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -25,7 +25,7 @@ export type ParamType = 'number'|'string'|'string[]'|'number[]'|'bool'|'bool[]'| export type Category = 'arithmetic'|'basic_math'|'control'|'convolution'|'custom'|'dynamic'| 'evaluation'|'image'|'creation'|'graph'|'logical'|'matrices'| - 'normalization'|'reduction'|'slice_join'|'spectral'|'transformation'; + 'normalization'|'reduction'|'slice_join'|'spectral'|'string'|'transformation'; // For mapping input or attributes of NodeDef into TensorFlow.js op param. export declare interface ParamMapper { diff --git a/tfjs-core/src/backends/backend.ts b/tfjs-core/src/backends/backend.ts index d5e9cd02b4b..9013217995a 100644 --- a/tfjs-core/src/backends/backend.ts +++ b/tfjs-core/src/backends/backend.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -17,7 +17,7 @@ import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util'; import {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_util'; -import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor'; +import {Backend, DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor'; import {BackendValues, DataType, Rank, ShapeMap} from '../types'; export const EPSILON_FLOAT32 = 1e-7; @@ -659,6 +659,15 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { dispose(): void { return notYetImplemented('dispose'); } + + encodeBase64(str: StringTensor|Tensor, pad = false): + T { + throw new Error('Not yet implemented'); + } + + decodeBase64(str: StringTensor|Tensor): T { + throw new Error('Not yet implemented'); + } } function notYetImplemented(kernelName: string): never { diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index f45855309a8..362bfc1faeb 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2017 Google Inc. All Rights Reserved. + * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -35,7 +35,7 @@ import {buffer, scalar, tensor, tensor4d} from '../../ops/ops'; import * as scatter_nd_util from '../../ops/scatter_nd_util'; import * as selu_util from '../../ops/selu_util'; import {computeFlatOffset, computeOutShape, isSliceContinous} from '../../ops/slice_util'; -import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor'; +import {DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor'; import {BackendValues, DataType, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; @@ -44,6 +44,7 @@ import * as backend_util from '../backend_util'; import * as complex_util from '../complex_util'; import {nonMaxSuppressionV3} from '../non_max_suppression_impl'; import {split} from '../split_shared'; +import {decodeBase64Impl, encodeBase64Impl} from '../string_shared'; import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; import {whereImpl} from '../where_impl'; @@ -3552,6 +3553,17 @@ export class MathBackendCPU extends KernelBackend { dispose() {} + encodeBase64(str: StringTensor|Tensor, pad = false): + T { + const sVals = this.readSync(str.dataId) as Uint8Array[]; + return encodeBase64Impl(sVals, str.shape, pad); + } + + decodeBase64(str: StringTensor|Tensor): T { + const sVals = this.readSync(str.dataId) as Uint8Array[]; + return decodeBase64Impl(sVals, str.shape); + } + floatPrecision(): 16|32 { return 32; } diff --git a/tfjs-core/src/backends/string_shared.ts b/tfjs-core/src/backends/string_shared.ts new file mode 100644 index 00000000000..6990258b2b7 --- /dev/null +++ b/tfjs-core/src/backends/string_shared.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {arrayBufferToBase64String, base64StringToArrayBuffer, urlSafeBase64, urlUnsafeBase64} from '../io/io_utils'; +import {StringTensor} from '../tensor'; +import {decodeString} from '../util'; +import {ENGINE} from '../engine'; + +/** Shared implementation of the encodeBase64 kernel across WebGL and CPU. */ +export function encodeBase64Impl( + values: Uint8Array[], shape: number[], pad = false): T { + const resultValues = new Array(values.length); + + for (let i = 0; i < values.length; ++i) { + const bStr = arrayBufferToBase64String(values[i].buffer); + const bStrUrl = urlSafeBase64(bStr); + + if (pad) { + resultValues[i] = bStrUrl; + } else { + // Remove padding + resultValues[i] = bStrUrl.replace(/=/g, ''); + } + } + + return ENGINE.makeTensor(resultValues, shape, 'string') as T; +} + +/** Shared implementation of the decodeBase64 kernel across WebGL and CPU. */ +export function decodeBase64Impl( + values: Uint8Array[], shape: number[]): T { + const resultValues = new Array(values.length); + + for (let i = 0; i < values.length; ++i) { + // Undo URL safe and decode from Base64 to ArrayBuffer + const bStrUrl = decodeString(values[i]); + const bStr = urlUnsafeBase64(bStrUrl); + const aBuff = base64StringToArrayBuffer(bStr); + + resultValues[i] = decodeString(new Uint8Array(aBuff)); + } + + return ENGINE.makeTensor(resultValues, shape, 'string') as T; +} diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index a1dfe5c0086..b7cb19f3620 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2017 Google Inc. All Rights Reserved. + * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -38,7 +38,7 @@ import * as segment_util from '../../ops/segment_util'; import * as slice_util from '../../ops/slice_util'; import {softmax} from '../../ops/softmax'; import {range, scalar, tensor} from '../../ops/tensor_ops'; -import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor'; +import {DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor'; import {BackendValues, DataType, DataTypeMap, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util'; @@ -47,6 +47,7 @@ import * as backend_util from '../backend_util'; import {mergeRealAndImagArrays} from '../complex_util'; import {nonMaxSuppressionV3} from '../non_max_suppression_impl'; import {split} from '../split_shared'; +import {decodeBase64Impl, encodeBase64Impl} from '../string_shared'; import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; import {whereImpl} from '../where_impl'; @@ -2311,6 +2312,17 @@ export class MathBackendWebGL extends KernelBackend { return split(x, sizeSplits, axis); } + encodeBase64(str: StringTensor|Tensor, pad = false): + T { + const sVals = this.readSync(str.dataId) as Uint8Array[]; + return encodeBase64Impl(sVals, str.shape, pad); + } + + decodeBase64(str: StringTensor|Tensor): T { + const sVals = this.readSync(str.dataId) as Uint8Array[]; + return decodeBase64Impl(sVals, str.shape); + } + scatterND( indices: Tensor, updates: Tensor, shape: ShapeMap[R]): Tensor { const {sliceRank, numUpdates, sliceSize, strides, outputSize} = diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index ac0fe3314bc..3bbcb408734 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google LLC. All Rights Reserved. + * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -340,3 +340,21 @@ export function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts): modelArtifacts.weightData.byteLength, }; } + +/** + * Make Base64 string URL safe by replacing `+` with `-` and `/` with `_`. + * + * @param str Base64 string to make URL safe. + */ +export function urlSafeBase64(str: string): string { + return str.replace(/\+/g, '-').replace(/\//g, '_'); +} + +/** + * Revert Base64 URL safe changes by replacing `-` with `+` and `_` with `/`. + * + * @param str URL safe Base string to revert changes. + */ +export function urlUnsafeBase64(str: string): string { + return str.replace(/-/g, '+').replace(/_/g, '/'); +} diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 84706177c93..50987740ec9 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -55,6 +55,7 @@ export * from './gather_nd'; export * from './diag'; export * from './dropout'; export * from './signal_ops'; +export * from './string_ops'; export * from './in_top_k'; export {op} from './operation'; diff --git a/tfjs-core/src/ops/string_ops.ts b/tfjs-core/src/ops/string_ops.ts new file mode 100644 index 00000000000..503306b79e6 --- /dev/null +++ b/tfjs-core/src/ops/string_ops.ts @@ -0,0 +1,78 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENGINE} from '../engine'; +import {StringTensor, Tensor} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; + +import {op} from './operation'; + +/** + * Encodes the values of a `tf.Tensor` (of dtype `string`) to Base64. + * + * Given a String tensor, returns a new tensor with the values encoded into + * web-safe base64 format. + * + * Web-safe means that the encoder uses `-` and `_` instead of `+` and `/`: + * + * en.wikipedia.org/wiki/Base64 + * + * ```js + * const x = tf.tensor1d(['Hello world!'], 'string'); + * + * x.encodeBase64().print(); + * ``` + * @param str The input `tf.Tensor` of dtype `string` to encode. + * @param pad Whether to add padding (`=`) to the end of the encoded string. + */ +/** @doc {heading: 'Operations', subheading: 'String'} */ +function encodeBase64_( + str: StringTensor|Tensor, pad = false): T { + const $str = convertToTensor(str, 'str', 'encodeBase64', 'string'); + + const backwardsFunc = (dy: T) => ({$str: () => decodeBase64(dy)}); + + return ENGINE.runKernelFunc( + backend => backend.encodeBase64($str, pad), {$str}, backwardsFunc); +} + +/** + * Decodes the values of a `tf.Tensor` (of dtype `string`) from Base64. + * + * Given a String tensor of Base64 encoded values, returns a new tensor with the + * decoded values. + * + * en.wikipedia.org/wiki/Base64 + * + * ```js + * const y = tf.scalar('SGVsbG8gd29ybGQh', 'string'); + * + * y.decodeBase64().print(); + * ``` + * @param str The input `tf.Tensor` of dtype `string` to decode. + */ +/** @doc {heading: 'Operations', subheading: 'String'} */ +function decodeBase64_(str: StringTensor|Tensor): T { + const $str = convertToTensor(str, 'str', 'decodeBase64', 'string'); + + const backwardsFunc = (dy: T) => ({$str: () => encodeBase64(dy)}); + + return ENGINE.runKernelFunc( + backend => backend.decodeBase64($str), {$str}, backwardsFunc); +} + +export const encodeBase64 = op({encodeBase64_}); +export const decodeBase64 = op({decodeBase64_}); diff --git a/tfjs-core/src/ops/string_ops_test.ts b/tfjs-core/src/ops/string_ops_test.ts new file mode 100644 index 00000000000..5eecc442d1d --- /dev/null +++ b/tfjs-core/src/ops/string_ops_test.ts @@ -0,0 +1,105 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysEqual} from '../test_util'; + +const txtArr = [ + 'Hello TensorFlow.js!', '𝌆', 'Pre\u2014trained models with Base64 ops\u002e', + 'how about these? 🌍💻🍕', 'https://www.tensorflow.org/js', 'àβÇdéf', + '你好, 世界', `Build, train, & deploy +ML models in JS` +]; +const urlSafeB64 = [ + 'SGVsbG8gVGVuc29yRmxvdy5qcyE', '8J2Mhg', + 'UHJl4oCUdHJhaW5lZCBtb2RlbHMgd2l0aCBCYXNlNjQgb3BzLg', + 'aG93IGFib3V0IHRoZXNlPyDwn4yN8J-Su_CfjZU', + 'aHR0cHM6Ly93d3cudGVuc29yZmxvdy5vcmcvanM', 'w6DOssOHZMOpZg', + '5L2g5aW9LCDkuJbnlYw', 'QnVpbGQsIHRyYWluLCAmIGRlcGxveQpNTCBtb2RlbHMgaW4gSlM' +]; +const urlSafeB64Pad = [ + 'SGVsbG8gVGVuc29yRmxvdy5qcyE=', '8J2Mhg==', + 'UHJl4oCUdHJhaW5lZCBtb2RlbHMgd2l0aCBCYXNlNjQgb3BzLg==', + 'aG93IGFib3V0IHRoZXNlPyDwn4yN8J-Su_CfjZU=', + 'aHR0cHM6Ly93d3cudGVuc29yZmxvdy5vcmcvanM=', 'w6DOssOHZMOpZg==', + '5L2g5aW9LCDkuJbnlYw=', 'QnVpbGQsIHRyYWluLCAmIGRlcGxveQpNTCBtb2RlbHMgaW4gSlM=' +]; + +describeWithFlags('encodeBase64', ALL_ENVS, () => { + it('scalar', async () => { + const a = tf.scalar(txtArr[1], 'string'); + const r = tf.encodeBase64(a); + expect(r.shape).toEqual([]); + expectArraysEqual(await r.data(), urlSafeB64[1]); + }); + it('1D padded', async () => { + const a = tf.tensor1d([txtArr[2]], 'string'); + const r = tf.encodeBase64(a, true); + expect(r.shape).toEqual([1]); + expectArraysEqual(await r.data(), [urlSafeB64Pad[2]]); + }); + it('2D', async () => { + const a = tf.tensor2d(txtArr, [2, 4], 'string'); + const r = tf.encodeBase64(a, false); + expect(r.shape).toEqual([2, 4]); + expectArraysEqual(await r.data(), urlSafeB64); + }); + it('3D padded', async () => { + const a = tf.tensor3d(txtArr, [2, 2, 2], 'string'); + const r = tf.encodeBase64(a, true); + expect(r.shape).toEqual([2, 2, 2]); + expectArraysEqual(await r.data(), urlSafeB64Pad); + }); +}); + +describeWithFlags('decodeBase64', ALL_ENVS, () => { + it('scalar', async () => { + const a = tf.scalar(urlSafeB64[1], 'string'); + const r = tf.decodeBase64(a); + expect(r.shape).toEqual([]); + expectArraysEqual(await r.data(), txtArr[1]); + }); + it('1D padded', async () => { + const a = tf.tensor1d([urlSafeB64Pad[2]], 'string'); + const r = tf.decodeBase64(a); + expect(r.shape).toEqual([1]); + expectArraysEqual(await r.data(), [txtArr[2]]); + }); + it('2D', async () => { + const a = tf.tensor2d(urlSafeB64, [2, 4], 'string'); + const r = tf.decodeBase64(a); + expect(r.shape).toEqual([2, 4]); + expectArraysEqual(await r.data(), txtArr); + }); + it('3D padded', async () => { + const a = tf.tensor3d(urlSafeB64Pad, [2, 2, 2], 'string'); + const r = tf.decodeBase64(a); + expect(r.shape).toEqual([2, 2, 2]); + expectArraysEqual(await r.data(), txtArr); + }); +}); + +describeWithFlags('encodeBase64-decodeBase64', ALL_ENVS, () => { + it('round-trip', async () => { + const s = [txtArr.join('')]; + const a = tf.tensor(s, [1], 'string'); + const b = tf.encodeBase64(a); + const c = tf.decodeBase64(b); + expectArraysEqual(await c.data(), s); + }); +}); diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 27635c2d592..b55a045c8de 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2017 Google Inc. All Rights Reserved. + * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -367,6 +367,8 @@ export interface OpHandler { fft(x: Tensor): Tensor; ifft(x: Tensor): Tensor; rfft(x: Tensor): Tensor; irfft(x: Tensor): Tensor }; + encodeBase64(x: T, pad: boolean): T; + decodeBase64(x: T): T; } // For tracking tensor creation and disposal. @@ -1418,6 +1420,16 @@ export class Tensor { this.throwIfDisposed(); return opHandler.spectral.irfft(this); } + + encodeBase64(this: T, pad = false): T { + this.throwIfDisposed(); + return opHandler.encodeBase64(this, pad); + } + + decodeBase64(this: T): T { + this.throwIfDisposed(); + return opHandler.decodeBase64(this); + } } Object.defineProperty(Tensor, Symbol.hasInstance, { value: (instance: Tensor) => { diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index c1dde0528d8..3a4af447359 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -90,6 +90,7 @@ import './ops/softmax_test'; import './ops/sparse_to_dense_test'; import './ops/spectral_ops_test'; import './ops/strided_slice_test'; +import './ops/string_ops_test'; import './ops/topk_test'; import './ops/transpose_test'; import './ops/unary_ops_test'; diff --git a/tfjs-node/src/nodejs_kernel_backend.ts b/tfjs-node/src/nodejs_kernel_backend.ts index 0545aee0fdc..70c1e59e156 100644 --- a/tfjs-node/src/nodejs_kernel_backend.ts +++ b/tfjs-node/src/nodejs_kernel_backend.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -24,6 +24,8 @@ import {FusedBatchMatMulConfig, FusedConv2DConfig} from '@tensorflow/tfjs-core/d import {isArray, isNullOrUndefined} from 'util'; import {Int64Scalar} from './int64_tensors'; +// tslint:disable-next-line: no-imports-from-dist +import { StringTensor } from '@tensorflow/tfjs-core/dist/tensor'; import {TensorMetadata, TFEOpAttr, TFJSBinding} from './tfjs_binding'; type TensorData = { @@ -1751,6 +1753,20 @@ export class NodeJSKernelBackend extends KernelBackend { return this.executeSingleOutput('LinSpace', opAttrs, inputs) as Tensor1D; } + encodeBase64(str: StringTensor|Tensor, pad = false): + T { + const opAttrs = + [{name: 'pad', type: this.binding.TF_ATTR_BOOL, value: pad}]; + return this.executeSingleOutput('EncodeBase64', opAttrs, [str as Tensor]) as + T; + } + + decodeBase64(str: StringTensor|Tensor): T { + const opAttrs: TFEOpAttr[] = []; + return this.executeSingleOutput('DecodeBase64', opAttrs, [str as Tensor]) as + T; + } + decodeJpeg( contents: Uint8Array, channels: number, ratio: number, fancyUpscaling: boolean, tryRecoverTruncated: boolean,