Skip to content

Commit

Permalink
Strip control inputs from graph when creating new TF functions from t…
Browse files Browse the repository at this point in the history
…ensors. (#5)

* Strip control inputs from graph when creating new TF functions from tensors.

* Formatting

* Appease MyPy.

* Add dump_saved_model_to_graph

* Appease MyPy, again

* Add test for create_function_from_tensors.

* Formatting.

* Appease mypy for a third time.
  • Loading branch information
psobot authored Jul 24, 2023
1 parent 7ab7850 commit 26e62b3
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 4 deletions.
2 changes: 1 addition & 1 deletion realbook/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# limitations under the License.

__author__ = "Spotify"
__version__ = "1.0.2"
__version__ = "1.0.3"
__email__ = "[email protected]"
__description__ = "Python libraries for easier machine learning on audio"
__url__ = "https://github.com/spotify/realbook"
51 changes: 49 additions & 2 deletions realbook/layers/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
MessageToDict as SerializeProtobufToDict,
ParseDict as ParseDictToProtobuf,
)
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, cast, Dict, List, Optional, Tuple, Union
from tensorflow.python.framework.convert_to_constants import (
convert_variables_to_constants_v2_as_graph,
)
from tensorflow.python.framework.importer import _IsControlInput as is_control_input
from tensorflow.python.ops.op_selector import UnliftableError
from tensorflow.python.eager.wrap_function import WrappedFunction

Expand Down Expand Up @@ -227,6 +228,7 @@ def get_saved_model_output_tensors(saved_model_or_path: Union[tf.keras.Model, st
def create_function_from_tensors(
input_tensors: Union[tf.Tensor, List[tf.Tensor]],
output_tensors: Union[tf.Tensor, List[tf.Tensor]],
include_control_inputs: bool = False,
) -> WrappedFunction:
"""
Given two lists of tensors (input and output), this method will return a tf.function
Expand Down Expand Up @@ -269,9 +271,17 @@ def create_function_from_tensors(

graph_input_names = [t.name for t in graph.inputs]

graph_def = graph.as_graph_def()

if not include_control_inputs:
# If this graph has any control inputs in it, those inputs will
# likely not be convertible (nor do we want them in our converted model!)
for node in graph_def.node:
node.input[:] = [tensor_name for tensor_name in node.input if not is_control_input(tensor_name)]

try:
return _load_concrete_function_from_graph_def(
graph.as_graph_def(),
graph_def,
[t.name for t in input_tensors],
[t.name for t in output_tensors],
)
Expand Down Expand Up @@ -310,3 +320,40 @@ def call(self, _input: Union[tf.Tensor, List[tf.Tensor]]) -> Union[tf.Tensor, Li
return self.func(*_input)
else:
return self.func(_input)


def dump_saved_model_to_graph(
saved_model_or_path: Union[tf.keras.Model, str],
include_control_inputs: bool = False,
) -> bytes:
"""
Given a path to a SavedModel or an already loaded SavedModel,
render that SavedModel's default serving signature as a V1
TensorFlow .pb file, which is more easily visualized with tools like
Netron (https://netron.app).
The result of this function is a serialized GraphDef protobuf,
which can be saved to a .pb file directly:
```
with open("graph-to-visualize.pb", "wb") as f:
f.write(dump_saved_model_to_graph("my_model.savedmodel"))
```
"""
if isinstance(saved_model_or_path, str):
savedmodel = tf.saved_model.load(saved_model_or_path)
model = savedmodel.signatures["serving_default"]
model._backref = savedmodel # Without this, the SavedModel will be GC'd too early
else:
model = saved_model_or_path
if hasattr(model, "signatures"):
model = model.signatures["serving_default"]
_, graph_def = convert_variables_to_constants_v2_as_graph(model)

if not include_control_inputs:
# If this graph has any control inputs in it, those inputs will
# likely not be convertible (nor do we want them in our converted model!)
for node in graph_def.node:
node.input[:] = [tensor_name for tensor_name in node.input if not is_control_input(tensor_name)]

return cast(bytes, graph_def.SerializeToString())
25 changes: 24 additions & 1 deletion tests/layers/test_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import tensorflow as tf
from tensorflow.keras.layers import Dense, Concatenate
import numpy as np
import pytest
from contextlib import contextmanager
from tensorflow.python.framework.importer import _IsControlInput as is_control_input
from tensorflow.python.framework.convert_to_constants import (
convert_variables_to_constants_v2_as_graph,
)
Expand All @@ -31,6 +33,7 @@
FrozenGraphLayer,
SavedModelLayer,
get_all_tensors_from_saved_model,
create_function_from_tensors,
TensorWrapperLayer,
)

Expand Down Expand Up @@ -285,6 +288,26 @@ def test_tensor_wrapper_layer_multiple_inputs() -> None:
tensors = get_all_tensors_from_saved_model(saved_model_path)
layer = TensorWrapperLayer(tensors[0].graph.inputs, tensors[0].graph.outputs)

outputs = layer([value for _name, value in sorted(x.items(), key=lambda t: t[0])]) # type: ignore
outputs = layer([value for _name, value in sorted(x.items(), key=lambda t: t[0])])
for expected, actual in zip(expected_outputs, outputs):
assert np.allclose(actual, expected)


@pytest.mark.parametrize("include_control_inputs", [False, True])
def test_create_function_from_tensors(include_control_inputs: bool) -> None:
model = train_addition_model()
x = np.random.rand(1, model.input_shape[-1]).astype(np.float32)
expected_output = model.predict(x)

with keras_model_to_savedmodel(model) as saved_model_path:
tensors = get_all_tensors_from_saved_model(saved_model_path)
control_inputs: List[tf.Operation] = sum([op.control_inputs for op in set([t.op for t in tensors])], [])
assert control_inputs
fun = create_function_from_tensors(tensors[0], tensors[-1], include_control_inputs=include_control_inputs)
control_inputs_in_graph = [node.name for node in fun.graph.as_graph_def().node if is_control_input(node.name)]

# TODO(psobot): How do we generate a model with control inputs in TF 2.10+?
# Every attempt I made to generate such a model fails.
if not include_control_inputs:
assert not control_inputs_in_graph
assert np.allclose(fun(tf.constant(x)), expected_output)

0 comments on commit 26e62b3

Please sign in to comment.