diff --git a/realbook/__init__.py b/realbook/__init__.py index 193e4ea..45d06f5 100644 --- a/realbook/__init__.py +++ b/realbook/__init__.py @@ -16,7 +16,7 @@ # limitations under the License. __author__ = "Spotify" -__version__ = "1.0.2" +__version__ = "1.0.3" __email__ = "realbook@spotify.com" __description__ = "Python libraries for easier machine learning on audio" __url__ = "https://github.com/spotify/realbook" diff --git a/realbook/layers/compatibility.py b/realbook/layers/compatibility.py index 83ae9fe..836fdb7 100644 --- a/realbook/layers/compatibility.py +++ b/realbook/layers/compatibility.py @@ -21,10 +21,11 @@ MessageToDict as SerializeProtobufToDict, ParseDict as ParseDictToProtobuf, ) -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Tuple, Union from tensorflow.python.framework.convert_to_constants import ( convert_variables_to_constants_v2_as_graph, ) +from tensorflow.python.framework.importer import _IsControlInput as is_control_input from tensorflow.python.ops.op_selector import UnliftableError from tensorflow.python.eager.wrap_function import WrappedFunction @@ -227,6 +228,7 @@ def get_saved_model_output_tensors(saved_model_or_path: Union[tf.keras.Model, st def create_function_from_tensors( input_tensors: Union[tf.Tensor, List[tf.Tensor]], output_tensors: Union[tf.Tensor, List[tf.Tensor]], + include_control_inputs: bool = False, ) -> WrappedFunction: """ Given two lists of tensors (input and output), this method will return a tf.function @@ -269,9 +271,17 @@ def create_function_from_tensors( graph_input_names = [t.name for t in graph.inputs] + graph_def = graph.as_graph_def() + + if not include_control_inputs: + # If this graph has any control inputs in it, those inputs will + # likely not be convertible (nor do we want them in our converted model!) + for node in graph_def.node: + node.input[:] = [tensor_name for tensor_name in node.input if not is_control_input(tensor_name)] + try: return _load_concrete_function_from_graph_def( - graph.as_graph_def(), + graph_def, [t.name for t in input_tensors], [t.name for t in output_tensors], ) @@ -310,3 +320,40 @@ def call(self, _input: Union[tf.Tensor, List[tf.Tensor]]) -> Union[tf.Tensor, Li return self.func(*_input) else: return self.func(_input) + + +def dump_saved_model_to_graph( + saved_model_or_path: Union[tf.keras.Model, str], + include_control_inputs: bool = False, +) -> bytes: + """ + Given a path to a SavedModel or an already loaded SavedModel, + render that SavedModel's default serving signature as a V1 + TensorFlow .pb file, which is more easily visualized with tools like + Netron (https://netron.app). + + The result of this function is a serialized GraphDef protobuf, + which can be saved to a .pb file directly: + + ``` + with open("graph-to-visualize.pb", "wb") as f: + f.write(dump_saved_model_to_graph("my_model.savedmodel")) + ``` + """ + if isinstance(saved_model_or_path, str): + savedmodel = tf.saved_model.load(saved_model_or_path) + model = savedmodel.signatures["serving_default"] + model._backref = savedmodel # Without this, the SavedModel will be GC'd too early + else: + model = saved_model_or_path + if hasattr(model, "signatures"): + model = model.signatures["serving_default"] + _, graph_def = convert_variables_to_constants_v2_as_graph(model) + + if not include_control_inputs: + # If this graph has any control inputs in it, those inputs will + # likely not be convertible (nor do we want them in our converted model!) + for node in graph_def.node: + node.input[:] = [tensor_name for tensor_name in node.input if not is_control_input(tensor_name)] + + return cast(bytes, graph_def.SerializeToString()) diff --git a/tests/layers/test_compatibility.py b/tests/layers/test_compatibility.py index 9289a0b..66a3f21 100644 --- a/tests/layers/test_compatibility.py +++ b/tests/layers/test_compatibility.py @@ -21,7 +21,9 @@ import tensorflow as tf from tensorflow.keras.layers import Dense, Concatenate import numpy as np +import pytest from contextlib import contextmanager +from tensorflow.python.framework.importer import _IsControlInput as is_control_input from tensorflow.python.framework.convert_to_constants import ( convert_variables_to_constants_v2_as_graph, ) @@ -31,6 +33,7 @@ FrozenGraphLayer, SavedModelLayer, get_all_tensors_from_saved_model, + create_function_from_tensors, TensorWrapperLayer, ) @@ -285,6 +288,26 @@ def test_tensor_wrapper_layer_multiple_inputs() -> None: tensors = get_all_tensors_from_saved_model(saved_model_path) layer = TensorWrapperLayer(tensors[0].graph.inputs, tensors[0].graph.outputs) - outputs = layer([value for _name, value in sorted(x.items(), key=lambda t: t[0])]) # type: ignore + outputs = layer([value for _name, value in sorted(x.items(), key=lambda t: t[0])]) for expected, actual in zip(expected_outputs, outputs): assert np.allclose(actual, expected) + + +@pytest.mark.parametrize("include_control_inputs", [False, True]) +def test_create_function_from_tensors(include_control_inputs: bool) -> None: + model = train_addition_model() + x = np.random.rand(1, model.input_shape[-1]).astype(np.float32) + expected_output = model.predict(x) + + with keras_model_to_savedmodel(model) as saved_model_path: + tensors = get_all_tensors_from_saved_model(saved_model_path) + control_inputs: List[tf.Operation] = sum([op.control_inputs for op in set([t.op for t in tensors])], []) + assert control_inputs + fun = create_function_from_tensors(tensors[0], tensors[-1], include_control_inputs=include_control_inputs) + control_inputs_in_graph = [node.name for node in fun.graph.as_graph_def().node if is_control_input(node.name)] + + # TODO(psobot): How do we generate a model with control inputs in TF 2.10+? + # Every attempt I made to generate such a model fails. + if not include_control_inputs: + assert not control_inputs_in_graph + assert np.allclose(fun(tf.constant(x)), expected_output)