Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EncodeBase64 and DecodeBase64 ops #2004

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions tfjs-converter/docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,13 @@
|Not mapped|ifft|
|Not mapped|rfft|

## Operations - Strings

|Tensorflow Op Name|Tensorflow.js Op Name|
|---|---|
|DecodeBase64|decodeBase64|
|EncodeBase64|encodeBase64|

## Tensors - Transformations

|Tensorflow Op Name|Tensorflow.js Op Name|
Expand Down
31 changes: 31 additions & 0 deletions tfjs-converter/python/tensorflowjs/op_list/string.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[
{
"tfOpName": "DecodeBase64",
"category": "string",
"inputs": [
{
"start": 0,
"name": "input",
"type": "tensor"
}
]
},
{
"tfOpName": "EncodeBase64",
"category": "string",
"inputs": [
{
"start": 0,
"name": "input",
"type": "tensor"
}
],
"attrs": [
{
"tfName": "pad",
"name": "pad",
"type": "bool"
}
]
}
]
46 changes: 46 additions & 0 deletions tfjs-converter/src/operations/executors/string_executor.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tfc from '@tensorflow/tfjs-core';

import {NamedTensorsMap} from '../../data/types';
import {ExecutionContext} from '../../executor/execution_context';
import {InternalOpExecutor, Node} from '../types';

import {getParamValue} from './utils';

/**
 * Executor for graph nodes in the 'string' category.
 *
 * Supported ops:
 *  - 'DecodeBase64': reads the 'str' param and passes it to
 *    tfc.decodeBase64.
 *  - 'EncodeBase64': reads the 'str' and 'pad' params and passes them to
 *    tfc.encodeBase64.
 *
 * @param node The graph node to execute; `node.op` selects the operation.
 * @param tensorMap Named tensors forwarded to getParamValue.
 * @param context Execution context forwarded to getParamValue.
 * @returns A single-element array holding the op's result tensor.
 * @throws TypeError when `node.op` is not one of the supported ops.
 */
export let executeOp: InternalOpExecutor =
    (node: Node, tensorMap: NamedTensorsMap,
     context: ExecutionContext): tfc.Tensor[] => {
      if (node.op === 'DecodeBase64') {
        const encoded =
            getParamValue('str', node, tensorMap, context) as tfc.Tensor;
        return [tfc.decodeBase64(encoded)];
      }

      if (node.op === 'EncodeBase64') {
        const raw =
            getParamValue('str', node, tensorMap, context) as tfc.Tensor;
        const usePadding =
            getParamValue('pad', node, tensorMap, context) as boolean;
        return [tfc.encodeBase64(raw, usePadding)];
      }

      // Any other op name is not handled by this executor.
      throw TypeError(`Node type ${node.op} is not implemented`);
    };

/** Category name exported by this executor module. */
export const CATEGORY = 'string';
62 changes: 62 additions & 0 deletions tfjs-converter/src/operations/executors/string_executor_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tfc from '@tensorflow/tfjs-core';

import {ExecutionContext} from '../../executor/execution_context';
import {Node} from '../types';

import {executeOp} from './string_executor';
import {createBoolAttr, createTensorAttr} from './test_helper';

describe('string', () => {
  const input1 = [tfc.tensor(['a'], [1], 'string')];
  const context = new ExecutionContext({}, {});
  let node: Node;

  /**
   * Builds a blank 'string'-category node whose `str` input param is bound
   * to input index 0. Each spec fills in `op` (and attrs) itself.
   */
  const makeBlankNode = (): Node => ({
    name: 'test',
    op: '',
    category: 'string',
    inputNames: ['input1'],
    inputs: [],
    inputParams: {str: createTensorAttr(0)},
    attrParams: {},
    children: []
  });

  // Give every spec a fresh node so mutations do not leak between specs.
  beforeEach(() => {
    node = makeBlankNode();
  });

  describe('executeOp', () => {
    describe('DecodeBase64', () => {
      it('should call tfc.decodeBase64', () => {
        node.op = 'DecodeBase64';
        spyOn(tfc, 'decodeBase64');

        executeOp(node, {input1}, context);

        expect(tfc.decodeBase64).toHaveBeenCalledWith(input1[0]);
      });
    });

    describe('EncodeBase64', () => {
      it('should call tfc.encodeBase64', () => {
        node.op = 'EncodeBase64';
        node.attrParams.pad = createBoolAttr(true);
        spyOn(tfc, 'encodeBase64');

        executeOp(node, {input1}, context);

        expect(tfc.encodeBase64).toHaveBeenCalledWith(input1[0], true);
      });
    });
  });
});
32 changes: 32 additions & 0 deletions tfjs-converter/src/operations/op_list/string.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import {OpMapper} from '../types';

/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

/**
 * Op mappers for the TensorFlow string ops (category 'string').
 *
 * The tensor input is named 'str' because the string executor looks it up
 * with getParamValue('str', ...). The mapper's input name and the executor's
 * lookup key must be identical, otherwise the executor receives no tensor.
 */
export const json: OpMapper[] = [
  {
    'tfOpName': 'DecodeBase64',
    'category': 'string',
    'inputs': [{'start': 0, 'name': 'str', 'type': 'tensor'}]
  },
  {
    'tfOpName': 'EncodeBase64',
    'category': 'string',
    'inputs': [{'start': 0, 'name': 'str', 'type': 'tensor'}],
    // 'pad' is read by the executor via getParamValue('pad', ...).
    'attrs': [{'tfName': 'pad', 'name': 'pad', 'type': 'bool'}]
  }
];
4 changes: 2 additions & 2 deletions tfjs-converter/src/operations/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand All @@ -25,7 +25,7 @@ export type ParamType = 'number'|'string'|'string[]'|'number[]'|'bool'|'bool[]'|
export type Category =
'arithmetic'|'basic_math'|'control'|'convolution'|'custom'|'dynamic'|
'evaluation'|'image'|'creation'|'graph'|'logical'|'matrices'|
'normalization'|'reduction'|'slice_join'|'spectral'|'transformation';
'normalization'|'reduction'|'slice_join'|'spectral'|'string'|'transformation';

// For mapping input or attributes of NodeDef into TensorFlow.js op param.
export declare interface ParamMapper {
Expand Down
13 changes: 11 additions & 2 deletions tfjs-core/src/backends/backend.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand All @@ -17,7 +17,7 @@

import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
import {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_util';
import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
import {Backend, DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
import {BackendValues, DataType, Rank, ShapeMap} from '../types';

export const EPSILON_FLOAT32 = 1e-7;
Expand Down Expand Up @@ -659,6 +659,15 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
/**
 * Base-class stub: delegates to notYetImplemented('dispose'), which is
 * declared to return `never`.
 */
dispose(): void {
return notYetImplemented('dispose');
}

/**
 * Base-class stub for base64-encoding the elements of a string tensor.
 * Always fails through the shared notYetImplemented helper, like the other
 * stubs in this class, so the kernel name is passed to the helper.
 *
 * @param str The string tensor to encode.
 * @param pad Padding flag; unused by this stub.
 * @returns Never returns; notYetImplemented is declared to return `never`.
 */
encodeBase64<T extends StringTensor>(str: StringTensor|Tensor, pad = false):
    T {
  return notYetImplemented('encodeBase64');
}

/**
 * Base-class stub for base64-decoding the elements of a string tensor.
 * Always fails through the shared notYetImplemented helper, like the other
 * stubs in this class, so the kernel name is passed to the helper.
 *
 * @param str The string tensor to decode.
 * @returns Never returns; notYetImplemented is declared to return `never`.
 */
decodeBase64<T extends StringTensor>(str: StringTensor|Tensor): T {
  return notYetImplemented('decodeBase64');
}
}

function notYetImplemented(kernelName: string): never {
Expand Down
16 changes: 14 additions & 2 deletions tfjs-core/src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand Down Expand Up @@ -35,7 +35,7 @@ import {buffer, scalar, tensor, tensor4d} from '../../ops/ops';
import * as scatter_nd_util from '../../ops/scatter_nd_util';
import * as selu_util from '../../ops/selu_util';
import {computeFlatOffset, computeOutShape, isSliceContinous} from '../../ops/slice_util';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
import {DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
import {BackendValues, DataType, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util';
Expand All @@ -44,6 +44,7 @@ import * as backend_util from '../backend_util';
import * as complex_util from '../complex_util';
import {nonMaxSuppressionV3} from '../non_max_suppression_impl';
import {split} from '../split_shared';
import {decodeBase64Impl, encodeBase64Impl} from '../string_shared';
import {tile} from '../tile_impl';
import {topkImpl} from '../topk_impl';
import {whereImpl} from '../where_impl';
Expand Down Expand Up @@ -3552,6 +3553,17 @@ export class MathBackendCPU extends KernelBackend {

/** Intentionally a no-op: the body is empty. */
dispose() {}

/**
 * Base64-encodes every element of a string tensor on the CPU backend.
 *
 * Synchronously reads the tensor's values (treated as one Uint8Array per
 * element) and hands them to the shared encodeBase64Impl.
 *
 * @param str The string tensor to encode.
 * @param pad Forwarded to encodeBase64Impl; defaults to false.
 * @returns The tensor produced by encodeBase64Impl, with the shape of `str`.
 */
encodeBase64<T extends StringTensor>(str: StringTensor|Tensor, pad = false):
    T {
  const {dataId, shape} = str;
  const elementBytes = this.readSync(dataId) as Uint8Array[];
  return encodeBase64Impl(elementBytes, shape, pad);
}

/**
 * Base64-decodes every element of a string tensor on the CPU backend.
 *
 * Synchronously reads the tensor's values (treated as one Uint8Array per
 * element) and hands them to the shared decodeBase64Impl.
 *
 * @param str The string tensor to decode.
 * @returns The tensor produced by decodeBase64Impl, with the shape of `str`.
 */
decodeBase64<T extends StringTensor>(str: StringTensor|Tensor): T {
  const {dataId, shape} = str;
  const elementBytes = this.readSync(dataId) as Uint8Array[];
  return decodeBase64Impl(elementBytes, shape);
}

/** Reports this backend's float precision; always 32 here. */
floatPrecision(): 16|32 {
return 32;
}
Expand Down
58 changes: 58 additions & 0 deletions tfjs-core/src/backends/string_shared.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {arrayBufferToBase64String, base64StringToArrayBuffer, urlSafeBase64, urlUnsafeBase64} from '../io/io_utils';
import {StringTensor} from '../tensor';
import {decodeString} from '../util';
import {ENGINE} from '../engine';

/**
 * Shared implementation of the encodeBase64 kernel across WebGL and CPU.
 *
 * Each element's bytes are base64-encoded, converted with urlSafeBase64, and
 * optionally stripped of '=' padding.
 *
 * @param values One Uint8Array per tensor element.
 * @param shape Shape of the resulting string tensor.
 * @param pad When true the '=' padding is kept; when false (default) every
 *     '=' is removed from the encoded string.
 * @returns A string tensor of the given shape holding the encoded values.
 */
export function encodeBase64Impl<T extends StringTensor>(
    values: Uint8Array[], shape: number[], pad = false): T {
  const resultValues = values.map(bytes => {
    // A Uint8Array can be a view over part of a larger ArrayBuffer. Encoding
    // `bytes.buffer` directly would then include bytes outside the view, so
    // copy out exactly the viewed range unless the view spans the buffer.
    const {buffer, byteOffset, byteLength} = bytes;
    const coversWholeBuffer =
        byteOffset === 0 && byteLength === buffer.byteLength;
    const exactBuffer = coversWholeBuffer ?
        buffer :
        buffer.slice(byteOffset, byteOffset + byteLength);

    const encoded = urlSafeBase64(arrayBufferToBase64String(exactBuffer));

    // Without padding, drop every '=' character.
    return pad ? encoded : encoded.replace(/=/g, '');
  });

  return ENGINE.makeTensor(resultValues, shape, 'string') as T;
}

/**
 * Shared implementation of the decodeBase64 kernel across WebGL and CPU.
 *
 * For each element: the bytes are turned into a string with decodeString,
 * passed through urlUnsafeBase64, decoded to an ArrayBuffer, and the
 * resulting bytes are turned into a string with decodeString again.
 *
 * NOTE(review): the decoded payload goes through decodeString before being
 * stored; presumably that is a text decode, in which case non-text binary
 * payloads may not round-trip byte-for-byte. Confirm against decodeString.
 *
 * @param values One Uint8Array per tensor element, holding the encoded text.
 * @param shape Shape of the resulting string tensor.
 * @returns A string tensor of the given shape holding the decoded values.
 */
export function decodeBase64Impl<T extends StringTensor>(
    values: Uint8Array[], shape: number[]): T {
  const decodedValues = values.map(encodedBytes => {
    // Undo the URL-safe form before decoding the base64 text.
    const standardBase64 = urlUnsafeBase64(decodeString(encodedBytes));
    const payload = base64StringToArrayBuffer(standardBase64);
    return decodeString(new Uint8Array(payload));
  });

  return ENGINE.makeTensor(decodedValues, shape, 'string') as T;
}
16 changes: 14 additions & 2 deletions tfjs-core/src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand Down Expand Up @@ -38,7 +38,7 @@ import * as segment_util from '../../ops/segment_util';
import * as slice_util from '../../ops/slice_util';
import {softmax} from '../../ops/softmax';
import {range, scalar, tensor} from '../../ops/tensor_ops';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
import {DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
import {BackendValues, DataType, DataTypeMap, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util';
Expand All @@ -47,6 +47,7 @@ import * as backend_util from '../backend_util';
import {mergeRealAndImagArrays} from '../complex_util';
import {nonMaxSuppressionV3} from '../non_max_suppression_impl';
import {split} from '../split_shared';
import {decodeBase64Impl, encodeBase64Impl} from '../string_shared';
import {tile} from '../tile_impl';
import {topkImpl} from '../topk_impl';
import {whereImpl} from '../where_impl';
Expand Down Expand Up @@ -2311,6 +2312,17 @@ export class MathBackendWebGL extends KernelBackend {
return split(x, sizeSplits, axis);
}

/**
 * Base64-encodes every element of a string tensor on the WebGL backend.
 *
 * Synchronously reads the tensor's values (treated as one Uint8Array per
 * element) and hands them to the shared encodeBase64Impl.
 *
 * @param str The string tensor to encode.
 * @param pad Forwarded to encodeBase64Impl; defaults to false.
 * @returns The tensor produced by encodeBase64Impl, with the shape of `str`.
 */
encodeBase64<T extends StringTensor>(str: StringTensor|Tensor, pad = false):
    T {
  const {dataId, shape} = str;
  const elementBytes = this.readSync(dataId) as Uint8Array[];
  return encodeBase64Impl(elementBytes, shape, pad);
}

/**
 * Base64-decodes every element of a string tensor on the WebGL backend.
 *
 * Synchronously reads the tensor's values (treated as one Uint8Array per
 * element) and hands them to the shared decodeBase64Impl.
 *
 * @param str The string tensor to decode.
 * @returns The tensor produced by decodeBase64Impl, with the shape of `str`.
 */
decodeBase64<T extends StringTensor>(str: StringTensor|Tensor): T {
  const {dataId, shape} = str;
  const elementBytes = this.readSync(dataId) as Uint8Array[];
  return decodeBase64Impl(elementBytes, shape);
}

scatterND<R extends Rank>(
indices: Tensor, updates: Tensor, shape: ShapeMap[R]): Tensor<R> {
const {sliceRank, numUpdates, sliceSize, strides, outputSize} =
Expand Down
Loading