diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index da1be96f012..e5175225bf0 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -54,7 +54,6 @@ jobs: fi pip install -r $REQUIREMENTS_FILE --progress-bar off --upgrade pip uninstall -y keras keras-nightly - pip install tf_keras==2.16.0 --progress-bar off --upgrade pip install -e "." --progress-bar off --upgrade - name: Test applications with pytest if: ${{ steps.filter.outputs.applications == 'true' }} diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 8d9717466db..dec1d407f41 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -48,7 +48,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0 with: name: SARIF file path: results.sarif @@ -56,6 +56,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@f09c1c0a94de965c15400f5634aa42fac8fb8f88 # v3.27.5 + uses: github/codeql-action/upload-sarif@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0 with: sarif_file: results.sarif diff --git a/guides/functional_api.py b/guides/functional_api.py index 7dbbfbbbe61..c174953179e 100644 --- a/guides/functional_api.py +++ b/guides/functional_api.py @@ -179,6 +179,7 @@ from this file, even if the code that built the model is no longer available. This saved file includes the: + - model architecture - model weight values (that were learned during training) - model training config, if any (as passed to `compile()`) diff --git a/keras/api/_tf_keras/keras/export/__init__.py b/keras/api/_tf_keras/keras/export/__init__.py index 68fa6029396..49f7a66972b 100644 --- a/keras/api/_tf_keras/keras/export/__init__.py +++ b/keras/api/_tf_keras/keras/export/__init__.py @@ -4,4 +4,4 @@ since your modifications would be overwritten. """ -from keras.src.export.export_lib import ExportArchive +from keras.src.export.saved_model import ExportArchive diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index 71db20bf394..4f13a596130 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -4,7 +4,7 @@ since your modifications would be overwritten. 
""" -from keras.src.export.export_lib import TFSMLayer +from keras.src.export.tfsm_layer import TFSMLayer from keras.src.layers import deserialize from keras.src.layers import serialize from keras.src.layers.activations.activation import Activation @@ -155,6 +155,12 @@ from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( + RandomColorDegeneration, +) +from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( + RandomColorJitter, +) from keras.src.layers.preprocessing.image_preprocessing.random_contrast import ( RandomContrast, ) @@ -170,12 +176,21 @@ from keras.src.layers.preprocessing.image_preprocessing.random_hue import ( RandomHue, ) +from keras.src.layers.preprocessing.image_preprocessing.random_posterization import ( + RandomPosterization, +) from keras.src.layers.preprocessing.image_preprocessing.random_rotation import ( RandomRotation, ) from keras.src.layers.preprocessing.image_preprocessing.random_saturation import ( RandomSaturation, ) +from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import ( + RandomSharpness, +) +from keras.src.layers.preprocessing.image_preprocessing.random_shear import ( + RandomShear, +) from keras.src.layers.preprocessing.image_preprocessing.random_translation import ( RandomTranslation, ) diff --git a/keras/api/export/__init__.py b/keras/api/export/__init__.py index 68fa6029396..49f7a66972b 100644 --- a/keras/api/export/__init__.py +++ b/keras/api/export/__init__.py @@ -4,4 +4,4 @@ since your modifications would be overwritten. """ -from keras.src.export.export_lib import ExportArchive +from keras.src.export.saved_model import ExportArchive diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py index 4c31ded2375..a4aaf7c9917 100644 --- a/keras/api/layers/__init__.py +++ b/keras/api/layers/__init__.py @@ -4,7 +4,7 @@ since your modifications would be overwritten. 
""" -from keras.src.export.export_lib import TFSMLayer +from keras.src.export.tfsm_layer import TFSMLayer from keras.src.layers import deserialize from keras.src.layers import serialize from keras.src.layers.activations.activation import Activation @@ -155,6 +155,12 @@ from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( + RandomColorDegeneration, +) +from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( + RandomColorJitter, +) from keras.src.layers.preprocessing.image_preprocessing.random_contrast import ( RandomContrast, ) @@ -170,12 +176,21 @@ from keras.src.layers.preprocessing.image_preprocessing.random_hue import ( RandomHue, ) +from keras.src.layers.preprocessing.image_preprocessing.random_posterization import ( + RandomPosterization, +) from keras.src.layers.preprocessing.image_preprocessing.random_rotation import ( RandomRotation, ) from keras.src.layers.preprocessing.image_preprocessing.random_saturation import ( RandomSaturation, ) +from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import ( + RandomSharpness, +) +from keras.src.layers.preprocessing.image_preprocessing.random_shear import ( + RandomShear, +) from keras.src.layers.preprocessing.image_preprocessing.random_translation import ( RandomTranslation, ) diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py index 2d067f1ab89..db10cedabaa 100644 --- a/keras/src/backend/common/variables.py +++ b/keras/src/backend/common/variables.py @@ -33,9 +33,11 @@ class Variable: autocast: Optional. Boolean indicating whether the variable supports autocasting. If `True`, the layer may first convert the variable to the compute data type when accessed. Defaults to `True`. - aggregation: Optional. String specifying how a distributed variable will - be aggregated. This serves as a semantic annotation, to be taken - into account by downstream backends or users. Defaults to `"mean"`. + aggregation: Optional string, one of `None`, `"none"`, `"mean"`, + `"sum"` or `"only_first_replica"` specifying how a distributed + variable will be aggregated. This serves as a semantic annotation, + to be taken into account by downstream backends or users. Defaults + to `"none"`. name: Optional. A unique name for the variable. Automatically generated if not set. @@ -93,7 +95,7 @@ def __init__( dtype=None, trainable=True, autocast=True, - aggregation="mean", + aggregation="none", name=None, ): name = name or auto_name(self.__class__.__name__) @@ -103,12 +105,21 @@ def __init__( "cannot contain character `/`. " f"Received: name={name}" ) - if aggregation not in ("none", "mean", "sum", "only_first_replica"): + if aggregation not in ( + None, + "none", + "mean", + "sum", + "only_first_replica", + ): raise ValueError( "Invalid valid for argument `aggregation`. Expected " - "one of {'none', 'mean', 'sum', 'only_first_replica'}. " + "one of `None`, `'none'`, `'mean'`, `'sum'`, " + "`'only_first_replica'`. 
" f"Received: aggregation={aggregation}" ) + if aggregation is None: + aggregation = "none" self._name = name parent_path = current_path() if parent_path: diff --git a/keras/src/backend/tensorflow/distribute_test.py b/keras/src/backend/tensorflow/distribute_test.py index 3c29777c582..195eb999d35 100644 --- a/keras/src/backend/tensorflow/distribute_test.py +++ b/keras/src/backend/tensorflow/distribute_test.py @@ -130,8 +130,8 @@ def test_variable_aggregation(self): with strategy.scope(): x = np.random.random((4, 4)) v1 = backend.Variable(x, dtype="float32") - self.assertEqual(v1.aggregation, "mean") - self.assertEqual(v1.value.aggregation, tf.VariableAggregation.MEAN) + self.assertEqual(v1.aggregation, "none") + self.assertEqual(v1.value.aggregation, tf.VariableAggregation.NONE) v2 = backend.Variable(x, dtype="float32", aggregation="sum") self.assertEqual(v2.aggregation, "sum") diff --git a/keras/src/backend/torch/export.py b/keras/src/backend/torch/export.py index 55fc68ed954..7de5653e9fb 100644 --- a/keras/src/backend/torch/export.py +++ b/keras/src/backend/torch/export.py @@ -1,35 +1,128 @@ -from keras.src import layers +import copy +import warnings + +import torch + from keras.src import tree +from keras.src.export.export_utils import convert_spec_to_tensor +from keras.src.utils.module_utils import tensorflow as tf +from keras.src.utils.module_utils import torch_xla class TorchExportArchive: def track(self, resource): - if not isinstance(resource, layers.Layer): - raise ValueError( - "Invalid resource type. Expected an instance of a " - "JAX-based Keras `Layer` or `Model`. " - f"Received instead an object of type '{type(resource)}'. " - f"Object received: {resource}" - ) + raise NotImplementedError( + "`track` is not implemented in the torch backend. Use" + "`track_and_add_endpoint` instead." + ) - if isinstance(resource, layers.Layer): - # Variables in the lists below are actually part of the trackables - # that get saved, because the lists are created in __init__. - variables = resource.variables - trainable_variables = resource.trainable_variables - non_trainable_variables = resource.non_trainable_variables - self._tf_trackable.variables += tree.map_structure( - self._convert_to_tf_variable, variables - ) - self._tf_trackable.trainable_variables += tree.map_structure( - self._convert_to_tf_variable, trainable_variables - ) - self._tf_trackable.non_trainable_variables += tree.map_structure( - self._convert_to_tf_variable, non_trainable_variables + def add_endpoint(self, name, fn, input_signature, **kwargs): + raise NotImplementedError( + "`add_endpoint` is not implemented in the torch backend. Use" + "`track_and_add_endpoint` instead." + ) + + def track_and_add_endpoint(self, name, resource, input_signature, **kwargs): + # Disable false alarms related to lifting parameters. + warnings.filterwarnings("ignore", message=".*created when tracing.*") + warnings.filterwarnings( + "ignore", message=".*Unable to find the path of the module.*" + ) + + if not isinstance(resource, torch.nn.Module): + raise TypeError( + "`resource` must be an instance of `torch.nn.Module`. " + f"Received: resource={resource} (of type {type(resource)})" ) - def add_endpoint(self, name, fn, input_signature=None, **kwargs): - # TODO: torch-xla? - raise NotImplementedError( - "`add_endpoint` is not implemented in the torch backend." 
+ sample_inputs = tree.map_structure( + lambda x: convert_spec_to_tensor(x, replace_none_number=1), + input_signature, + ) + sample_inputs = tuple(sample_inputs) + + # Ref: torch_xla.tf_saved_model_integration + # TODO: Utilize `dynamic_shapes` + exported = torch.export.export( + resource, sample_inputs, dynamic_shapes=None, strict=False + ) + options = torch_xla.stablehlo.StableHLOExportOptions( + override_tracing_arguments=sample_inputs + ) + stablehlo_model = torch_xla.stablehlo.exported_program_to_stablehlo( + exported, options + ) + state_dict_keys = list(stablehlo_model._bundle.state_dict.keys()) + + # Remove unused variables. + for k in state_dict_keys: + if "lifted" not in k: + stablehlo_model._bundle.state_dict.pop(k) + + bundle = copy.deepcopy(stablehlo_model._bundle) + bundle.state_dict = { + k: tf.Variable(v, trainable=False, name=k) + for k, v in bundle.state_dict.items() + } + bundle.additional_constants = [ + tf.Variable(v, trainable=False) for v in bundle.additional_constants + ] + + # Track variables in `bundle` for `write_out`. + self._tf_trackable.variables += ( + list(bundle.state_dict.values()) + bundle.additional_constants + ) + + # Ref: torch_xla.tf_saved_model_integration.save_stablehlo_graph_as_tf + def make_tf_function(func, bundle): + from tensorflow.compiler.tf2xla.python import xla as tfxla + + def _get_shape_with_dynamic(signature): + shape = copy.copy(signature.shape) + for i in signature.dynamic_dims: + shape[i] = None + return shape + + def _extract_call_parameters(args, meta, bundle): + call_args = [] + if meta.input_pytree_spec is not None: + args = tree.flatten(args) + for loc in meta.input_locations: + if loc.type_ == torch_xla.stablehlo.VariableType.PARAMETER: + call_args.append(bundle.state_dict[loc.name]) + elif loc.type_ == torch_xla.stablehlo.VariableType.CONSTANT: + call_args.append( + bundle.additional_constants[loc.position] + ) + else: + call_args.append(args[loc.position]) + return call_args + + def inner(*args): + Touts = [sig.dtype for sig in func.meta.output_signature] + Souts = [ + _get_shape_with_dynamic(sig) + for sig in func.meta.output_signature + ] + call_args = _extract_call_parameters(args, func.meta, bundle) + results = tfxla.call_module( + tuple(call_args), + version=5, + Tout=Touts, # dtype information + Sout=Souts, # Shape information + function_list=[], + module=func.bytecode, + ) + if len(Souts) == 1: + results = results[0] + return results + + return inner + + decorated_fn = tf.function( + make_tf_function( + stablehlo_model._bundle.stablehlo_funcs[0], bundle + ), + input_signature=input_signature, ) + return decorated_fn diff --git a/keras/src/callbacks/backup_and_restore.py b/keras/src/callbacks/backup_and_restore.py index a532c44a268..5e0b9524edb 100644 --- a/keras/src/callbacks/backup_and_restore.py +++ b/keras/src/callbacks/backup_and_restore.py @@ -37,6 +37,7 @@ class BackupAndRestore(Callback): >>> callback = keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup") >>> model = keras.models.Sequential([keras.layers.Dense(10)]) >>> model.compile(keras.optimizers.SGD(), loss='mse') + >>> model.build(input_shape=(None, 20)) >>> try: ... model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10, ... 
batch_size=1, callbacks=[callback, InterruptingCallback()], diff --git a/keras/src/export/__init__.py b/keras/src/export/__init__.py index d9de43f685a..a51487812ea 100644 --- a/keras/src/export/__init__.py +++ b/keras/src/export/__init__.py @@ -1 +1,4 @@ -from keras.src.export.export_lib import ExportArchive +from keras.src.export.onnx import export_onnx +from keras.src.export.saved_model import ExportArchive +from keras.src.export.saved_model import export_saved_model +from keras.src.export.tfsm_layer import TFSMLayer diff --git a/keras/src/export/export_utils.py b/keras/src/export/export_utils.py new file mode 100644 index 00000000000..bfb66180f4b --- /dev/null +++ b/keras/src/export/export_utils.py @@ -0,0 +1,105 @@ +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import tree +from keras.src.utils.module_utils import tensorflow as tf + + +def get_input_signature(model): + if not isinstance(model, models.Model): + raise TypeError( + "The model must be a `keras.Model`. " + f"Received: model={model} of the type {type(model)}" + ) + if not model.built: + raise ValueError( + "The model provided has not yet been built. It must be built " + "before export." + ) + if isinstance(model, (models.Functional, models.Sequential)): + input_signature = tree.map_structure(make_input_spec, model.inputs) + if isinstance(input_signature, list) and len(input_signature) > 1: + input_signature = [input_signature] + else: + input_signature = _infer_input_signature_from_model(model) + if not input_signature or not model._called: + raise ValueError( + "The model provided has never been called. " + "It must be called at least once before export." + ) + return input_signature + + +def _infer_input_signature_from_model(model): + shapes_dict = getattr(model, "_build_shapes_dict", None) + if not shapes_dict: + return None + + def _make_input_spec(structure): + # We need to turn wrapper structures like TrackingDict or _DictWrapper + # into plain Python structures because they don't work with jax2tf/JAX. + if isinstance(structure, dict): + return {k: _make_input_spec(v) for k, v in structure.items()} + elif isinstance(structure, tuple): + if all(isinstance(d, (int, type(None))) for d in structure): + return layers.InputSpec( + shape=(None,) + structure[1:], dtype=model.input_dtype + ) + return tuple(_make_input_spec(v) for v in structure) + elif isinstance(structure, list): + if all(isinstance(d, (int, type(None))) for d in structure): + return layers.InputSpec( + shape=[None] + structure[1:], dtype=model.input_dtype + ) + return [_make_input_spec(v) for v in structure] + else: + raise ValueError( + f"Unsupported type {type(structure)} for {structure}" + ) + + return [_make_input_spec(value) for value in shapes_dict.values()] + + +def make_input_spec(x): + if isinstance(x, layers.InputSpec): + if x.shape is None or x.dtype is None: + raise ValueError( + "The `shape` and `dtype` must be provided. " f"Received: x={x}" + ) + input_spec = x + elif isinstance(x, backend.KerasTensor): + shape = (None,) + backend.standardize_shape(x.shape)[1:] + dtype = backend.standardize_dtype(x.dtype) + input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=x.name) + elif backend.is_tensor(x): + shape = (None,) + backend.standardize_shape(x.shape)[1:] + dtype = backend.standardize_dtype(x.dtype) + input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=None) + else: + raise TypeError( + f"Unsupported x={x} of the type ({type(x)}). 
Supported types are: " + "`keras.InputSpec`, `keras.KerasTensor` and backend tensor." + ) + return input_spec + + +def make_tf_tensor_spec(x): + if isinstance(x, tf.TensorSpec): + tensor_spec = x + else: + input_spec = make_input_spec(x) + tensor_spec = tf.TensorSpec( + input_spec.shape, dtype=input_spec.dtype, name=input_spec.name + ) + return tensor_spec + + +def convert_spec_to_tensor(spec, replace_none_number=None): + shape = backend.standardize_shape(spec.shape) + if replace_none_number is not None: + replace_none_number = int(replace_none_number) + shape = tuple( + s if s is not None else replace_none_number for s in shape + ) + return ops.ones(shape, spec.dtype) diff --git a/keras/src/export/onnx.py b/keras/src/export/onnx.py new file mode 100644 index 00000000000..976c7a16247 --- /dev/null +++ b/keras/src/export/onnx.py @@ -0,0 +1,145 @@ +from keras.src import backend +from keras.src import tree +from keras.src.export.export_utils import convert_spec_to_tensor +from keras.src.export.export_utils import get_input_signature +from keras.src.export.export_utils import make_tf_tensor_spec +from keras.src.export.saved_model import DEFAULT_ENDPOINT_NAME +from keras.src.export.saved_model import ExportArchive +from keras.src.export.tf2onnx_lib import patch_tf2onnx + + +def export_onnx(model, filepath, verbose=True, input_signature=None, **kwargs): + """Export the model as a ONNX artifact for inference. + + This method lets you export a model to a lightweight ONNX artifact + that contains the model's forward pass only (its `call()` method) + and can be served via e.g. ONNX Runtime. + + The original code of the model (including any custom layers you may + have used) is *no longer* necessary to reload the artifact -- it is + entirely standalone. + + Args: + filepath: `str` or `pathlib.Path` object. The path to save the artifact. + verbose: `bool`. Whether to print a message during export. Defaults to + True`. + input_signature: Optional. Specifies the shape and dtype of the model + inputs. Can be a structure of `keras.InputSpec`, `tf.TensorSpec`, + `backend.KerasTensor`, or backend tensor. If not provided, it will + be automatically computed. Defaults to `None`. + **kwargs: Additional keyword arguments. + + **Note:** This feature is currently supported only with TensorFlow, JAX and + Torch backends. + + **Note:** The dtype policy must be "float32" for the model. You can further + optimize the ONNX artifact using the ONNX toolkit. Learn more here: + [https://onnxruntime.ai/docs/performance/](https://onnxruntime.ai/docs/performance/). + + **Note:** The dynamic shape feature is not yet supported with Torch + backend. As a result, you must fully define the shapes of the inputs using + `input_signature`. If `input_signature` is not provided, all instances of + `None` (such as the batch size) will be replaced with `1`. + + Example: + + ```python + # Export the model as a ONNX artifact + model.export("path/to/location", format="onnx") + + # Load the artifact in a different process/environment + ort_session = onnxruntime.InferenceSession("path/to/location") + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), input_data) + } + predictions = ort_session.run(None, ort_inputs) + ``` + """ + if input_signature is None: + input_signature = get_input_signature(model) + if not input_signature or not model._called: + raise ValueError( + "The model provided has never called. " + "It must be called at least once before export." 
+ ) + + if backend.backend() in ("tensorflow", "jax"): + from keras.src.utils.module_utils import tf2onnx + + input_signature = tree.map_structure( + make_tf_tensor_spec, input_signature + ) + decorated_fn = get_concrete_fn(model, input_signature, **kwargs) + + # Use `tf2onnx` to convert the `decorated_fn` to the ONNX format. + patch_tf2onnx() # TODO: Remove this once `tf2onnx` supports numpy 2. + tf2onnx.convert.from_function( + decorated_fn, input_signature, output_path=filepath + ) + + elif backend.backend() == "torch": + import torch + + sample_inputs = tree.map_structure( + lambda x: convert_spec_to_tensor(x, replace_none_number=1), + input_signature, + ) + sample_inputs = tuple(sample_inputs) + # TODO: Make dict model exportable. + if any(isinstance(x, dict) for x in sample_inputs): + raise ValueError( + "Currently, `export_onnx` in the torch backend doesn't support " + "dictionaries as inputs." + ) + + # Convert to ONNX using TorchScript-based ONNX Exporter. + # TODO: Use TorchDynamo-based ONNX Exporter once + # `torch.onnx.dynamo_export()` supports Keras models. + torch.onnx.export(model, sample_inputs, filepath, verbose=verbose) + else: + raise NotImplementedError( + "`export_onnx` is only compatible with TensorFlow, JAX and " + "Torch backends." + ) + + +def _check_jax_kwargs(kwargs): + kwargs = kwargs.copy() + if "is_static" not in kwargs: + kwargs["is_static"] = True + if "jax2tf_kwargs" not in kwargs: + # TODO: These options will be deprecated in JAX. We need to + # find another way to export ONNX. + kwargs["jax2tf_kwargs"] = { + "enable_xla": False, + "native_serialization": False, + } + if kwargs["is_static"] is not True: + raise ValueError( + "`is_static` must be `True` in `kwargs` when using the jax " + "backend." + ) + if kwargs["jax2tf_kwargs"]["enable_xla"] is not False: + raise ValueError( + "`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` " + "when using the jax backend." + ) + if kwargs["jax2tf_kwargs"]["native_serialization"] is not False: + raise ValueError( + "`native_serialization` must be `False` in " + "`kwargs['jax2tf_kwargs']` when using the jax backend." 
+ ) + return kwargs + + +def get_concrete_fn(model, input_signature, **kwargs): + """Get the `tf.function` associated with the model.""" + if backend.backend() == "jax": + kwargs = _check_jax_kwargs(kwargs) + export_archive = ExportArchive() + export_archive.track_and_add_endpoint( + DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs + ) + if backend.backend() == "tensorflow": + export_archive._filter_and_track_resources() + return export_archive._get_concrete_fn(DEFAULT_ENDPOINT_NAME) diff --git a/keras/src/export/onnx_test.py b/keras/src/export/onnx_test.py new file mode 100644 index 00000000000..2df09e3730f --- /dev/null +++ b/keras/src/export/onnx_test.py @@ -0,0 +1,216 @@ +"""Tests for ONNX exporting utilities.""" + +import os + +import numpy as np +import onnxruntime +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src import tree +from keras.src.export import onnx +from keras.src.saving import saving_lib +from keras.src.testing.test_utils import named_product + + +class CustomModel(models.Model): + def __init__(self, layer_list): + super().__init__() + self.layer_list = layer_list + + def call(self, input): + output = input + for layer in self.layer_list: + output = layer(output) + return output + + +def get_model(type="sequential", input_shape=(10,), layer_list=None): + layer_list = layer_list or [ + layers.Dense(10, activation="relu"), + layers.BatchNormalization(), + layers.Dense(1, activation="sigmoid"), + ] + if type == "sequential": + return models.Sequential(layer_list) + elif type == "functional": + input = output = tree.map_shape_structure(layers.Input, input_shape) + for layer in layer_list: + output = layer(output) + return models.Model(inputs=input, outputs=output) + elif type == "subclass": + return CustomModel(layer_list) + + +@pytest.mark.skipif( + backend.backend() not in ("tensorflow", "jax", "torch"), + reason=( + "`export_onnx` only currently supports the tensorflow, jax and torch " + "backends." 
+ ), +) +@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +class ExportONNXTest(testing.TestCase): + @parameterized.named_parameters( + named_product(model_type=["sequential", "functional", "subclass"]) + ) + def test_standard_model_export(self, model_type): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type) + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = model(ref_input) + + onnx.export_onnx(model, temp_filepath) + ort_session = onnxruntime.InferenceSession(temp_filepath) + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input]) + } + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + ort_inputs = { + k.name: v + for k, v in zip( + ort_session.get_inputs(), + [np.concatenate([ref_input, ref_input], axis=0)], + ) + } + ort_session.run(None, ort_inputs) + + @parameterized.named_parameters( + named_product(struct_type=["tuple", "array", "dict"]) + ) + def test_model_with_input_structure(self, struct_type): + if backend.backend() == "torch" and struct_type == "dict": + self.skipTest("The torch backend doesn't support the dict model.") + + class TupleModel(models.Model): + def call(self, inputs): + x, y = inputs + return ops.add(x, y) + + class ArrayModel(models.Model): + def call(self, inputs): + x = inputs[0] + y = inputs[1] + return ops.add(x, y) + + class DictModel(models.Model): + def call(self, inputs): + x = inputs["x"] + y = inputs["y"] + return ops.add(x, y) + + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") + if struct_type == "tuple": + model = TupleModel() + ref_input = (ref_input, ref_input * 2) + elif struct_type == "array": + model = ArrayModel() + ref_input = [ref_input, ref_input * 2] + elif struct_type == "dict": + model = DictModel() + ref_input = {"x": ref_input, "y": ref_input * 2} + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input)) + + onnx.export_onnx(model, temp_filepath) + ort_session = onnxruntime.InferenceSession(temp_filepath) + if isinstance(ref_input, dict): + ort_inputs = { + k.name: v + for k, v in zip(ort_session.get_inputs(), ref_input.values()) + } + else: + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), ref_input) + } + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) + + # Test with keras.saving_lib + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.keras" + ) + saving_lib.save_model(model, temp_filepath) + revived_model = saving_lib.load_model( + temp_filepath, + { + "TupleModel": TupleModel, + "ArrayModel": ArrayModel, + "DictModel": DictModel, + }, + ) + self.assertAllClose(ref_output, revived_model(ref_input)) + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model2") + onnx.export_onnx(revived_model, temp_filepath) + + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + bigger_ref_input = tree.map_structure( + lambda x: np.concatenate([x, x], axis=0), ref_input + ) + if isinstance(bigger_ref_input, dict): + bigger_ort_inputs = { + k.name: v + for k, v 
in zip( + ort_session.get_inputs(), bigger_ref_input.values() + ) + } + else: + bigger_ort_inputs = { + k.name: v + for k, v in zip(ort_session.get_inputs(), bigger_ref_input) + } + ort_session.run(None, bigger_ort_inputs) + + def test_model_with_multiple_inputs(self): + class TwoInputsModel(models.Model): + def call(self, x, y): + return x + y + + def build(self, y_shape, x_shape): + self.built = True + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = TwoInputsModel() + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input_x = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_input_y = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = model(ref_input_x, ref_input_y) + + onnx.export_onnx(model, temp_filepath) + ort_session = onnxruntime.InferenceSession(temp_filepath) + ort_inputs = { + k.name: v + for k, v in zip( + ort_session.get_inputs(), [ref_input_x, ref_input_y] + ) + } + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + ort_inputs = { + k.name: v + for k, v in zip( + ort_session.get_inputs(), + [ + np.concatenate([ref_input_x, ref_input_x], axis=0), + np.concatenate([ref_input_y, ref_input_y], axis=0), + ], + ) + } + ort_session.run(None, ort_inputs) diff --git a/keras/src/export/export_lib.py b/keras/src/export/saved_model.py similarity index 71% rename from keras/src/export/export_lib.py rename to keras/src/export/saved_model.py index f91c0b9a609..1546e91aadf 100644 --- a/keras/src/export/export_lib.py +++ b/keras/src/export/saved_model.py @@ -1,11 +1,11 @@ -"""Library for exporting inference-only Keras models/layers.""" +"""Library for exporting SavedModel for Keras models/layers.""" from keras.src import backend from keras.src import layers from keras.src import tree from keras.src.api_export import keras_export -from keras.src.models import Functional -from keras.src.models import Sequential +from keras.src.export.export_utils import get_input_signature +from keras.src.export.export_utils import make_tf_tensor_spec from keras.src.utils import io_utils from keras.src.utils.module_utils import tensorflow as tf @@ -35,6 +35,9 @@ ) +DEFAULT_ENDPOINT_NAME = "serve" + + @keras_export("keras.export.ExportArchive") class ExportArchive(BackendExportArchive): """ExportArchive is used to write SavedModel artifacts (e.g. for inference). @@ -91,7 +94,7 @@ class ExportArchive(BackendExportArchive): **Note on resource tracking:** - `ExportArchive` is able to automatically track all `tf.Variables` used + `ExportArchive` is able to automatically track all `keras.Variables` used by its endpoints, so most of the time calling `.track(model)` is not strictly required. However, if your model uses lookup layers such as `IntegerLookup`, `StringLookup`, or `TextVectorization`, @@ -104,9 +107,10 @@ class ExportArchive(BackendExportArchive): def __init__(self): super().__init__() - if backend.backend() not in ("tensorflow", "jax"): + if backend.backend() not in ("tensorflow", "jax", "torch"): raise NotImplementedError( - "The export API is only compatible with JAX and TF backends." + "`ExportArchive` is only compatible with TensorFlow, JAX and " + "Torch backends." ) self._endpoint_names = [] @@ -141,8 +145,8 @@ def track(self, resource): (`TextVectorization`, `IntegerLookup`, `StringLookup`) are automatically tracked in `add_endpoint()`. 
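For reference, the `ExportArchive` workflow described in the docstring above can be exercised end to end. A minimal sketch, assuming the TensorFlow backend (the artifact path is illustrative; the same flow is exercised in `saved_model_test.py` below):

```python
import tensorflow as tf

from keras.src import layers
from keras.src import models
from keras.src.export import saved_model

# Build a small model by calling it once before tracking it.
model = models.Sequential([layers.Dense(2)])
model(tf.random.normal((2, 3)))

# Low-level flow: track the variables, register an endpoint, write out.
export_archive = saved_model.ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    "call",
    model.__call__,
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("/tmp/my_export")

# The endpoint is reloadable by name from the SavedModel.
revived = tf.saved_model.load("/tmp/my_export")
print(revived.call(tf.random.normal((4, 3))).shape)  # (4, 2)
```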
- Arguments: - resource: A trackable TensorFlow resource. + Args: + resource: A trackable Keras resource, such as a layer or model. """ if isinstance(resource, layers.Layer) and not resource.built: raise ValueError( @@ -325,7 +329,9 @@ def serving_fn(x): self._endpoint_names.append(name) return decorated_fn - input_signature = tree.map_structure(_make_tensor_spec, input_signature) + input_signature = tree.map_structure( + make_tf_tensor_spec, input_signature + ) decorated_fn = BackendExportArchive.add_endpoint( self, name, fn, input_signature, **kwargs ) @@ -334,12 +340,80 @@ def serving_fn(x): self._endpoint_names.append(name) return decorated_fn + def track_and_add_endpoint(self, name, resource, input_signature, **kwargs): + """Track the variables and register a new serving endpoint. + + This function combines the functionality of `track` and `add_endpoint`. + It tracks the variables of the `resource` (either a layer or a model) + and registers a serving endpoint using `resource.__call__`. + + Args: + name: `str`. The name of the endpoint. + resource: A trackable Keras resource, such as a layer or model. + input_signature: Specifies the shape and dtype of the inputs to + `resource.__call__`. Can be a structure of `keras.InputSpec`, + `tf.TensorSpec`, `backend.KerasTensor`, or backend tensor. + **kwargs: Additional keyword arguments: + - Specific to the JAX backend: + - `is_static`: Optional `bool`. Indicates whether the + endpoint function is static. Set to `False` if it + involves state updates (e.g., RNG seeds). + - `jax2tf_kwargs`: Optional `dict`. Arguments for + `jax2tf.convert`. See [`jax2tf.convert`]( + https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + If `native_serialization` and `polymorphic_shapes` are + not provided, they are automatically computed. + + """ + if name in self._endpoint_names: + raise ValueError(f"Endpoint name '{name}' is already taken.") + if not isinstance(resource, layers.Layer): + raise ValueError( + "Invalid resource type. Expected an instance of a Keras " + "`Layer` or `Model`. " + f"Received: resource={resource} (of type {type(resource)})" + ) + if not resource.built: + raise ValueError( + "The layer provided has not yet been built. " + "It must be built before export." + ) + if backend.backend() != "jax": + if "jax2tf_kwargs" in kwargs or "is_static" in kwargs: + raise ValueError( + "'jax2tf_kwargs' and 'is_static' are only supported with " + f"the jax backend. Current backend: {backend.backend()}" + ) + + input_signature = tree.map_structure( + make_tf_tensor_spec, input_signature + ) + + if not hasattr(BackendExportArchive, "track_and_add_endpoint"): + # Default behavior. + self.track(resource) + return self.add_endpoint( + name, resource.__call__, input_signature, **kwargs + ) + else: + # Special case for the torch backend. + decorated_fn = BackendExportArchive.track_and_add_endpoint( + self, name, resource, input_signature, **kwargs + ) + self._endpoint_signatures[name] = input_signature + setattr(self._tf_trackable, name, decorated_fn) + self._endpoint_names.append(name) + return decorated_fn + def add_variable_collection(self, name, variables): """Register a set of variables to be retrieved after reloading. Arguments: name: The string name for the collection. - variables: A tuple/list/set of `tf.Variable` instances. 
+ variables: A tuple/list/set of `keras.Variable` instances. Example: @@ -496,9 +570,6 @@ def export_saved_model( ): """Export the model as a TensorFlow SavedModel artifact for inference. - **Note:** This feature is currently supported only with TensorFlow and - JAX backends. - This method lets you export a model to a lightweight SavedModel artifact that contains the model's forward pass only (its `call()` method) and can be served via e.g. TensorFlow Serving. The forward pass is @@ -527,6 +598,14 @@ def export_saved_model( If `native_serialization` and `polymorphic_shapes` are not provided, they are automatically computed. + **Note:** This feature is currently supported only with TensorFlow, JAX and + Torch backends. Support for the Torch backend is experimental. + + **Note:** The dynamic shape feature is not yet supported with Torch + backend. As a result, you must fully define the shapes of the inputs using + `input_signature`. If `input_signature` is not provided, all instances of + `None` (such as the batch size) will be replaced with `1`. + Example: ```python @@ -543,218 +622,13 @@ def export_saved_model( `export()` method relies on `ExportArchive` internally. """ export_archive = ExportArchive() - export_archive.track(model) - if isinstance(model, (Functional, Sequential)): - if input_signature is None: - input_signature = tree.map_structure( - _make_tensor_spec, model.inputs - ) - if isinstance(input_signature, list) and len(input_signature) > 1: - input_signature = [input_signature] - export_archive.add_endpoint( - "serve", model.__call__, input_signature, **kwargs - ) - else: - if input_signature is None: - input_signature = _get_input_signature(model) - if not input_signature or not model._called: - raise ValueError( - "The model provided has never called. " - "It must be called at least once before export." - ) - export_archive.add_endpoint( - "serve", model.__call__, input_signature, **kwargs - ) - export_archive.write_out(filepath, verbose=verbose) - - -def _get_input_signature(model): - shapes_dict = getattr(model, "_build_shapes_dict", None) - if not shapes_dict: - return None - - def make_tensor_spec(structure): - # We need to turn wrapper structures like TrackingDict or _DictWrapper - # into plain Python structures because they don't work with jax2tf/JAX. - if isinstance(structure, dict): - return {k: make_tensor_spec(v) for k, v in structure.items()} - elif isinstance(structure, tuple): - if all(isinstance(d, (int, type(None))) for d in structure): - return tf.TensorSpec( - shape=(None,) + structure[1:], dtype=model.input_dtype - ) - return tuple(make_tensor_spec(v) for v in structure) - elif isinstance(structure, list): - if all(isinstance(d, (int, type(None))) for d in structure): - return tf.TensorSpec( - shape=[None] + structure[1:], dtype=model.input_dtype - ) - return [make_tensor_spec(v) for v in structure] - else: - raise ValueError( - f"Unsupported type {type(structure)} for {structure}" - ) - - return [make_tensor_spec(value) for value in shapes_dict.values()] - - -@keras_export("keras.layers.TFSMLayer") -class TFSMLayer(layers.Layer): - """Reload a Keras model/layer that was saved via SavedModel / ExportArchive. - - Arguments: - filepath: `str` or `pathlib.Path` object. The path to the SavedModel. - call_endpoint: Name of the endpoint to use as the `call()` method - of the reloaded layer. If the SavedModel was created - via `model.export()`, - then the default endpoint name is `'serve'`. In other cases - it may be named `'serving_default'`. 
- - Example: + if input_signature is None: + input_signature = get_input_signature(model) - ```python - model.export("path/to/artifact") - reloaded_layer = TFSMLayer("path/to/artifact") - outputs = reloaded_layer(inputs) - ``` - - The reloaded object can be used like a regular Keras layer, and supports - training/fine-tuning of its trainable weights. Note that the reloaded - object retains none of the internal structure or custom methods of the - original object -- it's a brand new layer created around the saved - function. - - **Limitations:** - - * Only call endpoints with a single `inputs` tensor argument - (which may optionally be a dict/tuple/list of tensors) are supported. - For endpoints with multiple separate input tensor arguments, consider - subclassing `TFSMLayer` and implementing a `call()` method with a - custom signature. - * If you need training-time behavior to differ from inference-time behavior - (i.e. if you need the reloaded object to support a `training=True` argument - in `__call__()`), make sure that the training-time call function is - saved as a standalone endpoint in the artifact, and provide its name - to the `TFSMLayer` via the `call_training_endpoint` argument. - """ - - def __init__( - self, - filepath, - call_endpoint="serve", - call_training_endpoint=None, - trainable=True, - name=None, - dtype=None, - ): - if backend.backend() != "tensorflow": - raise NotImplementedError( - "The TFSMLayer is only currently supported with the " - "TensorFlow backend." - ) - - # Initialize an empty layer, then add_weight() etc. as needed. - super().__init__(trainable=trainable, name=name, dtype=dtype) - - self._reloaded_obj = tf.saved_model.load(filepath) - - self.filepath = filepath - self.call_endpoint = call_endpoint - self.call_training_endpoint = call_training_endpoint - - # Resolve the call function. - if hasattr(self._reloaded_obj, call_endpoint): - # Case 1: it's set as an attribute. - self.call_endpoint_fn = getattr(self._reloaded_obj, call_endpoint) - elif call_endpoint in self._reloaded_obj.signatures: - # Case 2: it's listed in the `signatures` field. - self.call_endpoint_fn = self._reloaded_obj.signatures[call_endpoint] - else: - raise ValueError( - f"The endpoint '{call_endpoint}' " - "is neither an attribute of the reloaded SavedModel, " - "nor an entry in the `signatures` field of " - "the reloaded SavedModel. Select another endpoint via " - "the `call_endpoint` argument. Available endpoints for " - "this SavedModel: " - f"{list(self._reloaded_obj.signatures.keys())}" - ) - - # Resolving the training function. - if call_training_endpoint: - if hasattr(self._reloaded_obj, call_training_endpoint): - self.call_training_endpoint_fn = getattr( - self._reloaded_obj, call_training_endpoint - ) - elif call_training_endpoint in self._reloaded_obj.signatures: - self.call_training_endpoint_fn = self._reloaded_obj.signatures[ - call_training_endpoint - ] - else: - raise ValueError( - f"The endpoint '{call_training_endpoint}' " - "is neither an attribute of the reloaded SavedModel, " - "nor an entry in the `signatures` field of " - "the reloaded SavedModel. Available endpoints for " - "this SavedModel: " - f"{list(self._reloaded_obj.signatures.keys())}" - ) - - # Add trainable and non-trainable weights from the call_endpoint_fn. 
- all_fns = [self.call_endpoint_fn] - if call_training_endpoint: - all_fns.append(self.call_training_endpoint_fn) - tvs, ntvs = _list_variables_used_by_fns(all_fns) - for v in tvs: - self._add_existing_weight(v) - for v in ntvs: - self._add_existing_weight(v) - self.built = True - - def _add_existing_weight(self, weight): - """Tracks an existing weight.""" - self._track_variable(weight) - - def call(self, inputs, training=False, **kwargs): - if training: - if self.call_training_endpoint: - return self.call_training_endpoint_fn(inputs, **kwargs) - return self.call_endpoint_fn(inputs, **kwargs) - - def get_config(self): - base_config = super().get_config() - config = { - # Note: this is not intended to be portable. - "filepath": self.filepath, - "call_endpoint": self.call_endpoint, - "call_training_endpoint": self.call_training_endpoint, - } - return {**base_config, **config} - - -def _make_tensor_spec(x): - if isinstance(x, layers.InputSpec): - if x.shape is None or x.dtype is None: - raise ValueError( - "The `shape` and `dtype` must be provided. " f"Received: x={x}" - ) - tensor_spec = tf.TensorSpec(x.shape, dtype=x.dtype, name=x.name) - elif isinstance(x, tf.TensorSpec): - tensor_spec = x - elif isinstance(x, backend.KerasTensor): - shape = (None,) + backend.standardize_shape(x.shape)[1:] - tensor_spec = tf.TensorSpec(shape, dtype=x.dtype, name=x.name) - elif backend.is_tensor(x): - shape = (None,) + backend.standardize_shape(x.shape)[1:] - dtype = backend.standardize_dtype(x.dtype) - tensor_spec = tf.TensorSpec(shape, dtype=dtype, name=None) - else: - raise TypeError( - f"Unsupported x={x} of the type ({type(x)}). Supported types are: " - "`keras.InputSpec`, `tf.TensorSpec`, `keras.KerasTensor` and " - "backend tensor." - ) - return tensor_spec + export_archive.track_and_add_endpoint( + DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs + ) + export_archive.write_out(filepath, verbose=verbose) def _print_signature(fn, name, verbose=True): diff --git a/keras/src/export/export_lib_test.py b/keras/src/export/saved_model_test.py similarity index 80% rename from keras/src/export/export_lib_test.py rename to keras/src/export/saved_model_test.py index 040830934eb..c5ad6c58690 100644 --- a/keras/src/export/export_lib_test.py +++ b/keras/src/export/saved_model_test.py @@ -1,4 +1,4 @@ -"""Tests for inference-only model/layer exporting utilities.""" +"""Tests for SavedModel exporting utilities.""" import os @@ -14,8 +14,7 @@ from keras.src import random from keras.src import testing from keras.src import tree -from keras.src import utils -from keras.src.export import export_lib +from keras.src.export import saved_model from keras.src.saving import saving_lib from keras.src.testing.test_utils import named_product @@ -50,10 +49,16 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None): @pytest.mark.skipif( - backend.backend() not in ("tensorflow", "jax"), - reason="Export only currently supports the TF and JAX backends.", + backend.backend() not in ("tensorflow", "jax", "torch"), + reason=( + "`export_saved_model` only currently supports the tensorflow, jax and " + "torch backends." 
+ ), ) @pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +@pytest.mark.skipif( + testing.torch_uses_gpu(), reason="Leads to core dumps on CI" +) class ExportSavedModelTest(testing.TestCase): @parameterized.named_parameters( named_product(model_type=["sequential", "functional", "subclass"]) @@ -61,18 +66,29 @@ class ExportSavedModelTest(testing.TestCase): def test_standard_model_export(self, model_type): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") model = get_model(model_type) - ref_input = tf.random.normal((3, 10)) + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") ref_output = model(ref_input) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose(ref_output, revived_model.serve(ref_input)) # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return revived_model.serve(tf.random.normal((6, 10))) @parameterized.named_parameters( named_product(model_type=["sequential", "functional", "subclass"]) ) + @pytest.mark.skipif( + backend.backend() == "torch", + reason=( + "RuntimeError: mutating a non-functional tensor with a " + "functional tensor is not allowed in the torch backend." + ), + ) def test_model_with_rng_export(self, model_type): class RandomLayer(layers.Layer): def __init__(self): @@ -89,7 +105,7 @@ def call(self, inputs): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertEqual(ref_output.shape, revived_model.serve(ref_input).shape) # Test with a different batch size @@ -102,6 +118,13 @@ def call(self, inputs): @parameterized.named_parameters( named_product(model_type=["sequential", "functional", "subclass"]) ) + @pytest.mark.skipif( + backend.backend() == "torch", + reason=( + "RuntimeError: mutating a non-functional tensor with a " + "functional tensor is not allowed in the torch backend." 
+ ), + ) def test_model_with_non_trainable_state_export(self, model_type): class StateLayer(layers.Layer): def __init__(self): @@ -118,7 +141,7 @@ def call(self, inputs): model = get_model(model_type, layer_list=[StateLayer()]) model(tf.random.normal((3, 10))) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) # The non-trainable counter is expected to increment @@ -136,13 +159,17 @@ def call(self, inputs): def test_model_with_tf_data_layer(self, model_type): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") model = get_model(model_type, layer_list=[layers.Rescaling(scale=2.0)]) - ref_input = tf.random.normal((3, 10)) + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") ref_output = model(ref_input) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose(ref_output, revived_model.serve(ref_input)) # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return revived_model.serve(tf.random.normal((6, 10))) @parameterized.named_parameters( @@ -166,30 +193,24 @@ def call(self, inputs): y = inputs["y"] return ops.add(x, y) + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") if struct_type == "tuple": model = TupleModel() - ref_input = (tf.random.normal((3, 10)), tf.random.normal((3, 10))) + ref_input = (ref_input, ref_input * 2) elif struct_type == "array": model = ArrayModel() - ref_input = [tf.random.normal((3, 10)), tf.random.normal((3, 10))] + ref_input = [ref_input, ref_input * 2] elif struct_type == "dict": model = DictModel() - ref_input = { - "x": tf.random.normal((3, 10)), - "y": tf.random.normal((3, 10)), - } + ref_input = {"x": ref_input, "y": ref_input * 2} temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input)) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose(ref_output, revived_model.serve(ref_input)) - # Test with a different batch size - bigger_input = tree.map_structure( - lambda x: tf.concat([x, x], axis=0), ref_input - ) - revived_model.serve(bigger_input) # Test with keras.saving_lib temp_filepath = os.path.join( @@ -205,7 +226,16 @@ def call(self, inputs): }, ) self.assertAllClose(ref_output, revived_model(ref_input)) - export_lib.export_saved_model(revived_model, self.get_temp_dir()) + saved_model.export_saved_model(revived_model, self.get_temp_dir()) + + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + bigger_input = tree.map_structure( + lambda x: tf.concat([x, x], axis=0), ref_input + ) + revived_model(bigger_input) def test_model_with_multiple_inputs(self): class TwoInputsModel(models.Model): @@ -217,16 +247,20 @@ def build(self, y_shape, x_shape): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") model = TwoInputsModel() - ref_input_x = tf.random.normal((3, 10)) - ref_input_y = tf.random.normal((3, 10)) + batch_size = 3 if backend.backend() != "torch" else 1 + 
ref_input_x = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_input_y = np.random.normal(size=(batch_size, 10)).astype("float32") ref_output = model(ref_input_x, ref_input_y) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose( ref_output, revived_model.serve(ref_input_x, ref_input_y) ) # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return revived_model.serve( tf.random.normal((6, 10)), tf.random.normal((6, 10)) ) @@ -247,25 +281,28 @@ def build(self, y_shape, x_shape): def test_input_signature(self, model_type, input_signature): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") model = get_model(model_type) - ref_input = ops.random.uniform((3, 10)) + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = ops.random.normal((batch_size, 10)) ref_output = model(ref_input) if input_signature == "backend_tensor": input_signature = (ref_input,) else: input_signature = (input_signature,) - export_lib.export_saved_model( + saved_model.export_saved_model( model, temp_filepath, input_signature=input_signature ) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(ref_output, revived_model.serve(ref_input)) + self.assertAllClose( + ref_output, revived_model.serve(ops.convert_to_numpy(ref_input)) + ) def test_input_signature_error(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") model = get_model("functional") with self.assertRaisesRegex(TypeError, "Unsupported x="): input_signature = (123,) - export_lib.export_saved_model( + saved_model.export_saved_model( model, temp_filepath, input_signature=input_signature ) @@ -289,7 +326,7 @@ def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs): ref_input = ops.random.uniform((3, 10)) ref_output = model(ref_input) - export_lib.export_saved_model( + saved_model.export_saved_model( model, temp_filepath, is_static=is_static, @@ -300,10 +337,18 @@ def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs): @pytest.mark.skipif( - backend.backend() not in ("tensorflow", "jax"), + backend.backend() + not in ( + "tensorflow", + "jax", + # "torch", # TODO: Support low-level operations in the torch backend. 
+ ), reason="Export only currently supports the TF and JAX backends.", ) @pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +@pytest.mark.skipif( + testing.torch_uses_gpu(), reason="Leads to core dumps on CI" +) class ExportArchiveTest(testing.TestCase): @parameterized.named_parameters( named_product(model_type=["sequential", "functional", "subclass"]) @@ -316,13 +361,13 @@ def test_low_level_model_export(self, model_type): ref_output = model(ref_input) # Test variable tracking - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) self.assertLen(export_archive.variables, 8) self.assertLen(export_archive.trainable_variables, 6) self.assertLen(export_archive.non_trainable_variables, 2) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -342,7 +387,7 @@ def test_low_level_model_export_with_alias(self): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) fn = export_archive.add_endpoint( "call", @@ -383,7 +428,7 @@ def call(self, inputs): ref_input = [tf.random.normal((3, 8)), tf.random.normal((3, 6))] ref_output = model(ref_input) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -414,7 +459,7 @@ def test_low_level_model_export_with_jax2tf_kwargs(self): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -463,7 +508,7 @@ def call(self, inputs): # This will fail because the polymorphic_shapes that is # automatically generated will not account for the fact that # dynamic dimensions 1 and 2 must have the same value. 
- export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -473,7 +518,7 @@ def call(self, inputs): ) export_archive.write_out(temp_filepath) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -497,7 +542,7 @@ def test_endpoint_registration_tf_function(self): ref_output = model(ref_input) # Test variable tracking - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) self.assertLen(export_archive.variables, 8) self.assertLen(export_archive.trainable_variables, 6) @@ -562,7 +607,7 @@ def infer_fn(x): # Export with TF inference function as endpoint temp_filepath = os.path.join(self.get_temp_dir(), "my_model") - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint("serve", infer_fn) export_archive.write_out(temp_filepath) @@ -637,7 +682,7 @@ def infer_fn(x): # Export with TF inference function as endpoint temp_filepath = os.path.join(self.get_temp_dir(), "my_model") - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint("serve", infer_fn) export_archive.write_out(temp_filepath) @@ -661,7 +706,7 @@ def test_layer_export(self): ref_input = tf.random.normal((3, 10)) ref_output = layer(ref_input) # Build layer (important) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(layer) export_archive.add_endpoint( "call", @@ -683,7 +728,7 @@ def test_multi_input_output_functional_model(self): ref_inputs = [tf.random.normal((3, 2)), tf.random.normal((3, 2))] ref_outputs = model(ref_inputs) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "serve", @@ -713,7 +758,7 @@ def test_multi_input_output_functional_model(self): } ref_outputs = model(ref_inputs) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "serve", @@ -753,7 +798,7 @@ def test_multi_input_output_functional_model(self): # ref_input = tf.convert_to_tensor(["one two three four"]) # ref_output = model(ref_input) - # export_lib.export_saved_model(model, temp_filepath) + # saved_model.export_saved_model(model, temp_filepath) # revived_model = tf.saved_model.load(temp_filepath) # self.assertAllClose(ref_output, revived_model.serve(ref_input)) @@ -766,7 +811,7 @@ def test_track_multiple_layers(self): ref_input_2 = tf.random.normal((3, 5)) ref_output_2 = layer_2(ref_input_2) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.add_endpoint( "call_1", layer_1.call, @@ -789,7 +834,7 @@ def test_non_standard_layer_signature(self): x1 = tf.random.normal((3, 2, 2)) x2 = tf.random.normal((3, 2, 2)) ref_output = layer(x1, x2) # Build layer (important) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(layer) export_archive.add_endpoint( "call", @@ -810,7 +855,7 @@ def test_non_standard_layer_signature_with_kwargs(self): x1 = tf.random.normal((3, 2, 2)) x2 = tf.random.normal((3, 2, 2)) ref_output = layer(x1, x2) # Build layer 
(important) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(layer) export_archive.add_endpoint( "call", @@ -840,7 +885,7 @@ def test_variable_collection(self): ) # Test variable tracking - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -862,13 +907,13 @@ def test_export_saved_model_errors(self): # Model has not been built model = models.Sequential([layers.Dense(2)]) with self.assertRaisesRegex(ValueError, "It must be built"): - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) # Subclassed model has not been called model = get_model("subclass") model.build((2, 10)) with self.assertRaisesRegex(ValueError, "It must be called"): - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) def test_export_archive_errors(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") @@ -876,7 +921,7 @@ def test_export_archive_errors(self): model(tf.random.normal((2, 3))) # Endpoint name reuse - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -893,18 +938,18 @@ def test_export_archive_errors(self): ) # Write out with no endpoints - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) with self.assertRaisesRegex(ValueError, "No endpoints have been set"): export_archive.write_out(temp_filepath) # Invalid object type with self.assertRaisesRegex(ValueError, "Invalid resource type"): - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track("model") # Set endpoint with no input signature - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) with self.assertRaisesRegex( ValueError, "you must provide an `input_signature`" @@ -912,14 +957,14 @@ def test_export_archive_errors(self): export_archive.add_endpoint("call", model.__call__) # Set endpoint that has never been called - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) @tf.function() def my_endpoint(x): return model(x) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) with self.assertRaisesRegex( ValueError, "you must either provide a function" @@ -932,7 +977,7 @@ def test_export_no_assets(self): # Case where there are legitimately no assets. 
model = models.Sequential([layers.Flatten()]) model(tf.random.normal((2, 3))) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.add_endpoint( "call", model.__call__, @@ -954,133 +999,3 @@ def test_model_export_method(self, model_type): self.assertAllClose(ref_output, revived_model.serve(ref_input)) # Test with a different batch size revived_model.serve(tf.random.normal((6, 10))) - - -@pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="TFSM Layer reloading is only for the TF backend.", -) -class TestTFSMLayer(testing.TestCase): - def test_reloading_export_archive(self): - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - model = get_model() - ref_input = tf.random.normal((3, 10)) - ref_output = model(ref_input) - - export_lib.export_saved_model(model, temp_filepath) - reloaded_layer = export_lib.TFSMLayer(temp_filepath) - self.assertAllClose(reloaded_layer(ref_input), ref_output, atol=1e-7) - self.assertLen(reloaded_layer.weights, len(model.weights)) - self.assertLen( - reloaded_layer.trainable_weights, len(model.trainable_weights) - ) - self.assertLen( - reloaded_layer.non_trainable_weights, - len(model.non_trainable_weights), - ) - - # TODO(nkovela): Expand test coverage/debug fine-tuning and - # non-trainable use cases here. - - def test_reloading_default_saved_model(self): - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - model = get_model() - ref_input = tf.random.normal((3, 10)) - ref_output = model(ref_input) - - tf.saved_model.save(model, temp_filepath) - reloaded_layer = export_lib.TFSMLayer( - temp_filepath, call_endpoint="serving_default" - ) - # The output is a dict, due to the nature of SavedModel saving. - new_output = reloaded_layer(ref_input) - self.assertAllClose( - new_output[list(new_output.keys())[0]], - ref_output, - atol=1e-7, - ) - self.assertLen(reloaded_layer.weights, len(model.weights)) - self.assertLen( - reloaded_layer.trainable_weights, len(model.trainable_weights) - ) - self.assertLen( - reloaded_layer.non_trainable_weights, - len(model.non_trainable_weights), - ) - - def test_call_training(self): - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - utils.set_random_seed(1337) - model = models.Sequential( - [ - layers.Input((10,)), - layers.Dense(10), - layers.Dropout(0.99999), - ] - ) - export_archive = export_lib.ExportArchive() - export_archive.track(model) - export_archive.add_endpoint( - name="call_inference", - fn=lambda x: model(x, training=False), - input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], - ) - export_archive.add_endpoint( - name="call_training", - fn=lambda x: model(x, training=True), - input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], - ) - export_archive.write_out(temp_filepath) - reloaded_layer = export_lib.TFSMLayer( - temp_filepath, - call_endpoint="call_inference", - call_training_endpoint="call_training", - ) - inference_output = reloaded_layer( - tf.random.normal((1, 10)), training=False - ) - training_output = reloaded_layer( - tf.random.normal((1, 10)), training=True - ) - self.assertAllClose(np.mean(training_output), 0.0, atol=1e-7) - self.assertNotAllClose(np.mean(inference_output), 0.0, atol=1e-7) - - def test_serialization(self): - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - model = get_model() - ref_input = tf.random.normal((3, 10)) - ref_output = model(ref_input) - - export_lib.export_saved_model(model, temp_filepath) - 
reloaded_layer = export_lib.TFSMLayer(temp_filepath) - - # Test reinstantiation from config - config = reloaded_layer.get_config() - rereloaded_layer = export_lib.TFSMLayer.from_config(config) - self.assertAllClose(rereloaded_layer(ref_input), ref_output, atol=1e-7) - - # Test whole model saving with reloaded layer inside - model = models.Sequential([reloaded_layer]) - temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras") - model.save(temp_model_filepath, save_format="keras_v3") - reloaded_model = saving_lib.load_model( - temp_model_filepath, - custom_objects={"TFSMLayer": export_lib.TFSMLayer}, - ) - self.assertAllClose(reloaded_model(ref_input), ref_output, atol=1e-7) - - def test_errors(self): - # Test missing call endpoint - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - model = models.Sequential([layers.Input((2,)), layers.Dense(3)]) - export_lib.export_saved_model(model, temp_filepath) - with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): - export_lib.TFSMLayer(temp_filepath, call_endpoint="wrong") - - # Test missing call training endpoint - with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): - export_lib.TFSMLayer( - temp_filepath, - call_endpoint="serve", - call_training_endpoint="wrong", - ) diff --git a/keras/src/export/tf2onnx_lib.py b/keras/src/export/tf2onnx_lib.py new file mode 100644 index 00000000000..8dee72b2af4 --- /dev/null +++ b/keras/src/export/tf2onnx_lib.py @@ -0,0 +1,180 @@ +import copy +import functools +import logging +import traceback + +import numpy as np + + +@functools.lru_cache() +def patch_tf2onnx(): + """Patches `tf2onnx` to ensure compatibility with numpy>=2.0.0.""" + + from onnx import AttributeProto + from onnx import TensorProto + + from keras.src.utils.module_utils import tf2onnx + + logger = logging.getLogger(tf2onnx.__name__) + + def patched_rewrite_constant_fold(g, ops): + """ + We call the tensorflow transform with constant folding, but in some + cases tensorflow does not fold all constants. Since there are a bunch + of ops in onnx that use attributes where tensorflow has dynamic inputs, + we badly want constant folding to work. For cases where tensorflow + missed something, make another pass over the graph and fix what we + care about.
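+ For example, a `Range` op whose `start`/`limit`/`delta` inputs are all
+ constant can be replaced by a single const node holding
+ `np.arange(start, limit, delta)`; `Cast`, `ConcatV2`, `Pack` and the
+ other entries of `func_map` below are folded the same way.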
+ """ + func_map = { + "Add": np.add, + "GreaterEqual": np.greater_equal, + "Cast": np.asarray, + "ConcatV2": np.concatenate, + "Less": np.less, + "ListDiff": np.setdiff1d, + "Mul": np.multiply, + "Pack": np.stack, + "Range": np.arange, + "Sqrt": np.sqrt, + "Sub": np.subtract, + } + ops = list(ops) + + keep_looking = True + while keep_looking: + keep_looking = False + for idx, op in enumerate(ops): + func = func_map.get(op.type) + if func is None: + continue + if set(op.output) & set(g.outputs): + continue + try: + inputs = [] + for node in op.inputs: + if not node.is_const(): + break + inputs.append(node.get_tensor_value(as_list=False)) + + logger.debug( + "op name %s, %s, %s", + op.name, + len(op.input), + len(inputs), + ) + if inputs and len(op.input) == len(inputs): + logger.info( + "folding node type=%s, name=%s" % (op.type, op.name) + ) + if op.type == "Cast": + dst = op.get_attr_int("to") + np_type = tf2onnx.utils.map_onnx_to_numpy_type(dst) + val = np.asarray(*inputs, dtype=np_type) + elif op.type == "ConcatV2": + axis = inputs[-1] + values = inputs[:-1] + val = func(tuple(values), axis) + elif op.type == "ListDiff": + out_type = op.get_attr_int("out_idx") + np_type = tf2onnx.utils.map_onnx_to_numpy_type( + out_type + ) + val = func(*inputs) + val = val.astype(np_type) + elif op.type in ["Pack"]: + # handle ops that need input array and axis + axis = op.get_attr_int("axis") + val = func(inputs, axis=axis) + elif op.type == "Range": + dtype = op.get_attr_int("Tidx") + np_type = tf2onnx.utils.map_onnx_to_numpy_type( + dtype + ) + val = func(*inputs, dtype=np_type) + else: + val = func(*inputs) + + new_node_name = tf2onnx.utils.make_name(op.name) + new_output_name = new_node_name + old_output_name = op.output[0] + old_node_name = op.name + logger.debug( + "create const node [%s] replacing [%s]", + new_node_name, + old_node_name, + ) + ops[idx] = g.make_const(new_node_name, val) + + logger.debug( + "replace old output [%s] with new output [%s]", + old_output_name, + new_output_name, + ) + # need to re-write the consumers input name to use the + # const name + consumers = g.find_output_consumers(old_output_name) + if consumers: + for consumer in consumers: + g.replace_input( + consumer, old_output_name, new_output_name + ) + + # keep looking until there is nothing we can fold. + # We keep the graph in topological order so if we + # folded, the result might help a following op. + keep_looking = True + except Exception as ex: + tb = traceback.format_exc() + logger.info("exception: %s, details: %s", ex, tb) + # ignore errors + + return ops + + def patched_get_value_attr(self, external_tensor_storage=None): + """ + Return onnx attr for value property of node. + Attr is modified to point to external tensor data stored in + external_tensor_storage, if included. 
+ """ + a = self._attr["value"] + if ( + external_tensor_storage is not None + and self in external_tensor_storage.node_to_modified_value_attr + ): + return external_tensor_storage.node_to_modified_value_attr[self] + if external_tensor_storage is None or a.type != AttributeProto.TENSOR: + return a + + def prod(x): + if hasattr(np, "product"): + return np.product(x) + else: + return np.prod(x) + + if ( + prod(a.t.dims) + > external_tensor_storage.external_tensor_size_threshold + ): + a = copy.deepcopy(a) + tensor_name = ( + self.name.strip() + + "_" + + str(external_tensor_storage.name_counter) + ) + for c in '~"#%&*:<>?/\\{|}': + tensor_name = tensor_name.replace(c, "_") + external_tensor_storage.name_counter += 1 + external_tensor_storage.name_to_tensor_data[tensor_name] = ( + a.t.raw_data + ) + external_tensor_storage.node_to_modified_value_attr[self] = a + a.t.raw_data = b"" + a.t.ClearField("raw_data") + location = a.t.external_data.add() + location.key = "location" + location.value = tensor_name + a.t.data_location = TensorProto.EXTERNAL + return a + + tf2onnx.tfonnx.rewrite_constant_fold = patched_rewrite_constant_fold + tf2onnx.graph.Node.get_value_attr = patched_get_value_attr diff --git a/keras/src/export/tfsm_layer.py b/keras/src/export/tfsm_layer.py new file mode 100644 index 00000000000..61859bf0fc2 --- /dev/null +++ b/keras/src/export/tfsm_layer.py @@ -0,0 +1,139 @@ +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.export.saved_model import _list_variables_used_by_fns +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export("keras.layers.TFSMLayer") +class TFSMLayer(layers.Layer): + """Reload a Keras model/layer that was saved via SavedModel / ExportArchive. + + Arguments: + filepath: `str` or `pathlib.Path` object. The path to the SavedModel. + call_endpoint: Name of the endpoint to use as the `call()` method + of the reloaded layer. If the SavedModel was created + via `model.export()`, + then the default endpoint name is `'serve'`. In other cases + it may be named `'serving_default'`. + + Example: + + ```python + model.export("path/to/artifact") + reloaded_layer = TFSMLayer("path/to/artifact") + outputs = reloaded_layer(inputs) + ``` + + The reloaded object can be used like a regular Keras layer, and supports + training/fine-tuning of its trainable weights. Note that the reloaded + object retains none of the internal structure or custom methods of the + original object -- it's a brand new layer created around the saved + function. + + **Limitations:** + + * Only call endpoints with a single `inputs` tensor argument + (which may optionally be a dict/tuple/list of tensors) are supported. + For endpoints with multiple separate input tensor arguments, consider + subclassing `TFSMLayer` and implementing a `call()` method with a + custom signature. + * If you need training-time behavior to differ from inference-time behavior + (i.e. if you need the reloaded object to support a `training=True` argument + in `__call__()`), make sure that the training-time call function is + saved as a standalone endpoint in the artifact, and provide its name + to the `TFSMLayer` via the `call_training_endpoint` argument. 
+ """ + + def __init__( + self, + filepath, + call_endpoint="serve", + call_training_endpoint=None, + trainable=True, + name=None, + dtype=None, + ): + if backend.backend() != "tensorflow": + raise NotImplementedError( + "The TFSMLayer is only currently supported with the " + "TensorFlow backend." + ) + + # Initialize an empty layer, then add_weight() etc. as needed. + super().__init__(trainable=trainable, name=name, dtype=dtype) + + self._reloaded_obj = tf.saved_model.load(filepath) + + self.filepath = filepath + self.call_endpoint = call_endpoint + self.call_training_endpoint = call_training_endpoint + + # Resolve the call function. + if hasattr(self._reloaded_obj, call_endpoint): + # Case 1: it's set as an attribute. + self.call_endpoint_fn = getattr(self._reloaded_obj, call_endpoint) + elif call_endpoint in self._reloaded_obj.signatures: + # Case 2: it's listed in the `signatures` field. + self.call_endpoint_fn = self._reloaded_obj.signatures[call_endpoint] + else: + raise ValueError( + f"The endpoint '{call_endpoint}' " + "is neither an attribute of the reloaded SavedModel, " + "nor an entry in the `signatures` field of " + "the reloaded SavedModel. Select another endpoint via " + "the `call_endpoint` argument. Available endpoints for " + "this SavedModel: " + f"{list(self._reloaded_obj.signatures.keys())}" + ) + + # Resolving the training function. + if call_training_endpoint: + if hasattr(self._reloaded_obj, call_training_endpoint): + self.call_training_endpoint_fn = getattr( + self._reloaded_obj, call_training_endpoint + ) + elif call_training_endpoint in self._reloaded_obj.signatures: + self.call_training_endpoint_fn = self._reloaded_obj.signatures[ + call_training_endpoint + ] + else: + raise ValueError( + f"The endpoint '{call_training_endpoint}' " + "is neither an attribute of the reloaded SavedModel, " + "nor an entry in the `signatures` field of " + "the reloaded SavedModel. Available endpoints for " + "this SavedModel: " + f"{list(self._reloaded_obj.signatures.keys())}" + ) + + # Add trainable and non-trainable weights from the call_endpoint_fn. + all_fns = [self.call_endpoint_fn] + if call_training_endpoint: + all_fns.append(self.call_training_endpoint_fn) + tvs, ntvs = _list_variables_used_by_fns(all_fns) + for v in tvs: + self._add_existing_weight(v) + for v in ntvs: + self._add_existing_weight(v) + self.built = True + + def _add_existing_weight(self, weight): + """Tracks an existing weight.""" + self._track_variable(weight) + + def call(self, inputs, training=False, **kwargs): + if training: + if self.call_training_endpoint: + return self.call_training_endpoint_fn(inputs, **kwargs) + return self.call_endpoint_fn(inputs, **kwargs) + + def get_config(self): + base_config = super().get_config() + config = { + # Note: this is not intended to be portable. 
+ "filepath": self.filepath, + "call_endpoint": self.call_endpoint, + "call_training_endpoint": self.call_training_endpoint, + } + return {**base_config, **config} diff --git a/keras/src/export/tfsm_layer_test.py b/keras/src/export/tfsm_layer_test.py new file mode 100644 index 00000000000..31cb1673cf1 --- /dev/null +++ b/keras/src/export/tfsm_layer_test.py @@ -0,0 +1,142 @@ +import os + +import numpy as np +import pytest +import tensorflow as tf + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src import utils +from keras.src.export import saved_model +from keras.src.export import tfsm_layer +from keras.src.export.saved_model_test import get_model +from keras.src.saving import saving_lib + + +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="TFSM Layer reloading is only for the TF backend.", +) +class TestTFSMLayer(testing.TestCase): + def test_reloading_export_archive(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + saved_model.export_saved_model(model, temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath) + self.assertAllClose(reloaded_layer(ref_input), ref_output, atol=1e-7) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + def test_reloading_default_saved_model(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + tf.saved_model.save(model, temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer( + temp_filepath, call_endpoint="serving_default" + ) + # The output is a dict, due to the nature of SavedModel saving. 
+ new_output = reloaded_layer(ref_input) + self.assertAllClose( + new_output[list(new_output.keys())[0]], + ref_output, + atol=1e-7, + ) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + def test_call_training(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + utils.set_random_seed(1337) + model = models.Sequential( + [ + layers.Input((10,)), + layers.Dense(10), + layers.Dropout(0.99999), + ] + ) + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="call_inference", + fn=lambda x: model(x, training=False), + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.add_endpoint( + name="call_training", + fn=lambda x: model(x, training=True), + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.write_out(temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer( + temp_filepath, + call_endpoint="call_inference", + call_training_endpoint="call_training", + ) + inference_output = reloaded_layer( + tf.random.normal((1, 10)), training=False + ) + training_output = reloaded_layer( + tf.random.normal((1, 10)), training=True + ) + self.assertAllClose(np.mean(training_output), 0.0, atol=1e-7) + self.assertNotAllClose(np.mean(inference_output), 0.0, atol=1e-7) + + def test_serialization(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + saved_model.export_saved_model(model, temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath) + + # Test reinstantiation from config + config = reloaded_layer.get_config() + rereloaded_layer = tfsm_layer.TFSMLayer.from_config(config) + self.assertAllClose(rereloaded_layer(ref_input), ref_output, atol=1e-7) + + # Test whole model saving with reloaded layer inside + model = models.Sequential([reloaded_layer]) + temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras") + model.save(temp_model_filepath, save_format="keras_v3") + reloaded_model = saving_lib.load_model( + temp_model_filepath, + custom_objects={"TFSMLayer": tfsm_layer.TFSMLayer}, + ) + self.assertAllClose(reloaded_model(ref_input), ref_output, atol=1e-7) + + def test_errors(self): + # Test missing call endpoint + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = models.Sequential([layers.Input((2,)), layers.Dense(3)]) + saved_model.export_saved_model(model, temp_filepath) + with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): + tfsm_layer.TFSMLayer(temp_filepath, call_endpoint="wrong") + + # Test missing call training endpoint + with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): + tfsm_layer.TFSMLayer( + temp_filepath, + call_endpoint="serve", + call_training_endpoint="wrong", + ) diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index 584a3cdc1f4..f9719bfe442 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -99,6 +99,12 @@ from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( + RandomColorDegeneration, +) +from 
keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( + RandomColorJitter, +) from keras.src.layers.preprocessing.image_preprocessing.random_contrast import ( RandomContrast, ) @@ -114,12 +120,21 @@ from keras.src.layers.preprocessing.image_preprocessing.random_hue import ( RandomHue, ) +from keras.src.layers.preprocessing.image_preprocessing.random_posterization import ( + RandomPosterization, +) from keras.src.layers.preprocessing.image_preprocessing.random_rotation import ( RandomRotation, ) from keras.src.layers.preprocessing.image_preprocessing.random_saturation import ( RandomSaturation, ) +from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import ( + RandomSharpness, +) +from keras.src.layers.preprocessing.image_preprocessing.random_shear import ( + RandomShear, +) from keras.src.layers.preprocessing.image_preprocessing.random_translation import ( RandomTranslation, ) diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py index 592468fe802..d336781c8b3 100644 --- a/keras/src/layers/attention/attention.py +++ b/keras/src/layers/attention/attention.py @@ -1,6 +1,7 @@ from keras.src import backend from keras.src import ops from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor from keras.src.layers.layer import Layer @@ -84,6 +85,8 @@ def __init__( f"Received: score_mode={score_mode}" ) + self._return_attention_scores = False + def build(self, input_shape): self._validate_inputs(input_shape) self.scale = None @@ -217,6 +220,7 @@ def call( use_causal_mask=False, ): self._validate_inputs(inputs=inputs, mask=mask) + self._return_attention_scores = return_attention_scores q = inputs[0] v = inputs[1] k = inputs[2] if len(inputs) > 2 else v @@ -226,16 +230,17 @@ def call( scores_mask = self._calculate_score_mask( scores, v_mask, use_causal_mask ) - result, attention_scores = self._apply_scores( + attention_output, attention_scores = self._apply_scores( scores=scores, value=v, scores_mask=scores_mask, training=training ) if q_mask is not None: # Mask of shape [batch_size, Tq, 1]. 
q_mask = ops.expand_dims(q_mask, axis=-1) - result *= ops.cast(q_mask, dtype=result.dtype) + attention_output *= ops.cast(q_mask, dtype=attention_output.dtype) if return_attention_scores: - return result, attention_scores - return result + return (attention_output, attention_scores) + else: + return attention_output def compute_mask(self, inputs, mask=None): self._validate_inputs(inputs=inputs, mask=mask) @@ -244,8 +249,49 @@ def compute_mask(self, inputs, mask=None): return ops.convert_to_tensor(mask[0]) def compute_output_shape(self, input_shape): - """Returns shape of value tensor dim, but for query tensor length""" - return (*input_shape[0][:-1], input_shape[1][-1]) + query_shape, value_shape, key_shape = input_shape + if key_shape is None: + key_shape = value_shape + + output_shape = (*query_shape[:-1], value_shape[-1]) + if self._return_attention_scores: + scores_shape = (query_shape[0], query_shape[1], key_shape[1]) + return output_shape, scores_shape + return output_shape + + def compute_output_spec( + self, + inputs, + mask=None, + return_attention_scores=False, + training=None, + use_causal_mask=False, + ): + # Validate and unpack inputs + self._validate_inputs(inputs, mask) + query = inputs[0] + value = inputs[1] + key = inputs[2] if len(inputs) > 2 else value + + # Compute primary output shape + output_shape = self.compute_output_shape( + [query.shape, value.shape, key.shape] + ) + output_spec = KerasTensor(output_shape, dtype=self.compute_dtype) + + # Handle attention scores if requested + if self._return_attention_scores or return_attention_scores: + scores_shape = ( + query.shape[0], + query.shape[1], + key.shape[1], + ) # (batch_size, Tq, Tv) + attention_scores_spec = KerasTensor( + scores_shape, dtype=self.compute_dtype + ) + return (output_spec, attention_scores_spec) + + return output_spec def _validate_inputs(self, inputs, mask=None): """Validates arguments of the call method.""" diff --git a/keras/src/layers/attention/attention_test.py b/keras/src/layers/attention/attention_test.py index de8dba64340..88598d72112 100644 --- a/keras/src/layers/attention/attention_test.py +++ b/keras/src/layers/attention/attention_test.py @@ -358,3 +358,74 @@ def test_attention_compute_output_shape(self): ), output.shape, ) + + def test_return_attention_scores_true(self): + """Test that the layer returns attention scores along with outputs.""" + # Generate dummy input data + query = np.random.random((2, 8, 16)).astype(np.float32) + value = np.random.random((2, 4, 16)).astype(np.float32) + + # Initialize the Attention layer + layer = layers.Attention() + + # Call the layer with return_attention_scores=True + output, attention_scores = layer( + [query, value], return_attention_scores=True + ) + + # Check the shape of the outputs + self.assertEqual(output.shape, (2, 8, 16)) # Output shape + self.assertEqual( + attention_scores.shape, (2, 8, 4) + ) # Attention scores shape + + def test_return_attention_scores_true_and_tuple(self): + """Test that the layer outputs are a tuple when + return_attention_scores=True.""" + # Generate dummy input data + query = np.random.random((2, 8, 16)).astype(np.float32) + value = np.random.random((2, 4, 16)).astype(np.float32) + + # Initialize the Attention layer + layer = layers.Attention() + + # Call the layer with return_attention_scores=True + outputs = layer([query, value], return_attention_scores=True) + + # Check that outputs is a tuple + self.assertIsInstance( + outputs, tuple, "Expected the outputs to be a tuple" + ) + + def 
test_return_attention_scores_true_tuple_then_unpack(self): + """Test that outputs can be unpacked correctly.""" + # Generate dummy input data + query = np.random.random((2, 8, 16)).astype(np.float32) + value = np.random.random((2, 4, 16)).astype(np.float32) + + # Initialize the Attention layer + layer = layers.Attention() + + # Call the layer with return_attention_scores=True + outputs = layer([query, value], return_attention_scores=True) + + # Unpack the outputs + output, attention_scores = outputs + + # Check the shape of the unpacked outputs + self.assertEqual(output.shape, (2, 8, 16)) # Output shape + self.assertEqual( + attention_scores.shape, (2, 8, 4) + ) # Attention scores shape + + def test_return_attention_scores_with_symbolic_tensors(self): + """Test to check outputs with symbolic tensors with + return_attention_scores = True""" + attention = layers.Attention() + x = layers.Input(shape=(3, 5)) + y = layers.Input(shape=(4, 5)) + output, attention_scores = attention( + [x, y], return_attention_scores=True + ) + self.assertEqual(output.shape, (None, 3, 5)) # Output shape + self.assertEqual(attention_scores.shape, (None, 3, 4)) diff --git a/keras/src/layers/core/dense_test.py b/keras/src/layers/core/dense_test.py index 2c2faac218a..b54c91c9e19 100644 --- a/keras/src/layers/core/dense_test.py +++ b/keras/src/layers/core/dense_test.py @@ -6,6 +6,7 @@ from keras.src import backend from keras.src import constraints +from keras.src import export from keras.src import layers from keras.src import models from keras.src import ops @@ -14,7 +15,6 @@ from keras.src import saving from keras.src import testing from keras.src.backend.common import keras_tensor -from keras.src.export import export_lib class DenseTest(testing.TestCase): @@ -566,7 +566,7 @@ def test_quantize_int8_when_lora_enabled(self): ref_input = tf.random.normal((2, 8)) ref_output = model(ref_input) model.export(temp_filepath, format="tf_saved_model") - reloaded_layer = export_lib.TFSMLayer(temp_filepath) + reloaded_layer = export.TFSMLayer(temp_filepath) self.assertAllClose( reloaded_layer(ref_input), ref_output, atol=1e-7 ) @@ -738,7 +738,7 @@ def test_quantize_float8_fitting(self): ref_input = tf.random.normal((2, 8)) ref_output = model(ref_input) model.export(temp_filepath, format="tf_saved_model") - reloaded_layer = export_lib.TFSMLayer(temp_filepath) + reloaded_layer = export.TFSMLayer(temp_filepath) self.assertAllClose(reloaded_layer(ref_input), ref_output) self.assertLen(reloaded_layer.weights, len(model.weights)) self.assertLen( diff --git a/keras/src/layers/core/einsum_dense_test.py b/keras/src/layers/core/einsum_dense_test.py index 796cb37fd76..3fcecef0310 100644 --- a/keras/src/layers/core/einsum_dense_test.py +++ b/keras/src/layers/core/einsum_dense_test.py @@ -6,6 +6,7 @@ from keras.src import backend from keras.src import constraints +from keras.src import export from keras.src import layers from keras.src import models from keras.src import ops @@ -13,7 +14,6 @@ from keras.src import random from keras.src import saving from keras.src import testing -from keras.src.export import export_lib class EinsumDenseTest(testing.TestCase): @@ -699,7 +699,7 @@ def test_quantize_int8_when_lora_enabled( ref_input = tf.random.normal(input_shape) ref_output = model(ref_input) model.export(temp_filepath, format="tf_saved_model") - reloaded_layer = export_lib.TFSMLayer(temp_filepath) + reloaded_layer = export.TFSMLayer(temp_filepath) self.assertAllClose( reloaded_layer(ref_input), ref_output, atol=1e-7 ) @@ -878,7 +878,7 @@ def 
test_quantize_float8_fitting(self): ref_input = tf.random.normal((2, 3)) ref_output = model(ref_input) model.export(temp_filepath, format="tf_saved_model") - reloaded_layer = export_lib.TFSMLayer(temp_filepath) + reloaded_layer = export.TFSMLayer(temp_filepath) self.assertAllClose(reloaded_layer(ref_input), ref_output) self.assertLen(reloaded_layer.weights, len(model.weights)) self.assertLen( diff --git a/keras/src/layers/core/embedding_test.py b/keras/src/layers/core/embedding_test.py index ac4b6d6c8c7..784216c4cc8 100644 --- a/keras/src/layers/core/embedding_test.py +++ b/keras/src/layers/core/embedding_test.py @@ -6,11 +6,11 @@ from keras.src import backend from keras.src import constraints +from keras.src import export from keras.src import layers from keras.src import models from keras.src import ops from keras.src import saving -from keras.src.export import export_lib from keras.src.testing import test_case @@ -439,7 +439,7 @@ def test_quantize_when_lora_enabled(self): ref_input = tf.random.normal((32, 3)) ref_output = model(ref_input) model.export(temp_filepath, format="tf_saved_model") - reloaded_layer = export_lib.TFSMLayer(temp_filepath) + reloaded_layer = export.TFSMLayer(temp_filepath) self.assertAllClose( reloaded_layer(ref_input), ref_output, atol=1e-7 ) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 2508153d23c..8e36bb20456 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -493,7 +493,7 @@ def add_weight( autocast=True, regularizer=None, constraint=None, - aggregation="mean", + aggregation="none", name=None, ): """Add a weight variable to the layer. @@ -520,10 +520,11 @@ def add_weight( constraint: Constraint object to call on the variable after any optimizer update, or string name of a built-in constraint. Defaults to `None`. - aggregation: String, one of `'mean'`, `'sum'`, - `'only_first_replica'`. Annotates the variable with the type - of multi-replica aggregation to be used for this variable - when writing custom data parallel training loops. + aggregation: Optional string, one of `None`, `"none"`, `"mean"`, + `"sum"` or `"only_first_replica"`. Annotates the variable with + the type of multi-replica aggregation to be used for this + variable when writing custom data parallel training loops. + Defaults to `"none"`. name: String name of the variable. Useful for debugging purposes. """ self._check_super_called() @@ -1471,9 +1472,19 @@ def _check_super_called(self): def _assert_input_compatibility(self, arg_0): if self.input_spec: - input_spec.assert_input_compatibility( - self.input_spec, arg_0, layer_name=self.name - ) + try: + input_spec.assert_input_compatibility( + self.input_spec, arg_0, layer_name=self.name + ) + except SystemError: + if backend.backend() == "torch": + # TODO: The torch backend failed the ONNX CI with the error: + # SystemError: returned a result with an exception set + # As a workaround, we are skipping this for now.
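+ # (Hedged note) Only `SystemError` is caught, and it is re-raised on
+ # every backend except torch, so genuine `input_spec` violations
+ # still surface everywhere else.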
+ pass + else: + raise def _get_call_context(self): """Returns currently active `CallContext`.""" diff --git a/keras/src/layers/preprocessing/image_preprocessing/equalization.py b/keras/src/layers/preprocessing/image_preprocessing/equalization.py index e58f4b254c0..555713bf854 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/equalization.py +++ b/keras/src/layers/preprocessing/image_preprocessing/equalization.py @@ -170,25 +170,31 @@ def _apply_equalization(self, channel, hist): ) return self.backend.numpy.take(lookup_table, indices) - def transform_images(self, images, transformations=None, **kwargs): - images = self.backend.cast(images, self.compute_dtype) - - if self.data_format == "channels_first": - channels = [] - for i in range(self.backend.core.shape(images)[-3]): - channel = images[..., i, :, :] - equalized = self._equalize_channel(channel, self.value_range) - channels.append(equalized) - equalized_images = self.backend.numpy.stack(channels, axis=-3) - else: - channels = [] - for i in range(self.backend.core.shape(images)[-1]): - channel = images[..., i] - equalized = self._equalize_channel(channel, self.value_range) - channels.append(equalized) - equalized_images = self.backend.numpy.stack(channels, axis=-1) - - return self.backend.cast(equalized_images, self.compute_dtype) + def transform_images(self, images, transformation, training=True): + if training: + images = self.backend.cast(images, self.compute_dtype) + + if self.data_format == "channels_first": + channels = [] + for i in range(self.backend.core.shape(images)[-3]): + channel = images[..., i, :, :] + equalized = self._equalize_channel( + channel, self.value_range + ) + channels.append(equalized) + equalized_images = self.backend.numpy.stack(channels, axis=-3) + else: + channels = [] + for i in range(self.backend.core.shape(images)[-1]): + channel = images[..., i] + equalized = self._equalize_channel( + channel, self.value_range + ) + channels.append(equalized) + equalized_images = self.backend.numpy.stack(channels, axis=-1) + + return self.backend.cast(equalized_images, self.compute_dtype) + return images def compute_output_shape(self, input_shape): return input_shape @@ -196,14 +202,19 @@ def compute_output_shape(self, input_shape): def compute_output_spec(self, inputs, **kwargs): return inputs - def transform_bounding_boxes(self, bounding_boxes, **kwargs): + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): return bounding_boxes - def transform_labels(self, labels, transformations=None, **kwargs): + def transform_labels(self, labels, transformation, training=True): return labels def transform_segmentation_masks( - self, segmentation_masks, transformations, **kwargs + self, segmentation_masks, transformation, training=True ): return segmentation_masks diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py new file mode 100644 index 00000000000..c3255c846eb --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py @@ -0,0 +1,132 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator + + +@keras_export("keras.layers.RandomColorDegeneration") +class RandomColorDegeneration(BaseImagePreprocessingLayer): + """Randomly 
performs the color degeneration operation on given images. + + The color degeneration operation first converts an image to gray scale, + then back to color. It then takes a weighted average between the original + image and the degenerated image. This makes colors appear duller. + + Args: + factor: A tuple of two floats or a single float. + `factor` controls the extent to which the + image colors are degenerated. `factor=0.0` makes this layer perform a + no-op operation, while a value of 1.0 uses the degenerated result + entirely. Values between 0 and 1 result in linear interpolation + between the original image and the degenerated image. + Values should be between `0.0` and `1.0`. If a tuple is used, a + `factor` is sampled between the two values for every image + augmented. If a single float is used, a value between `0.0` and the + passed float is sampled. In order to ensure the value is always the + same, please pass a tuple with two identical floats: `(0.5, 0.5)`. + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. This is + typically either `[0, 1]` or `[0, 255]` depending on how your + preprocessing pipeline is set up. Defaults to `(0, 255)`. + seed: Integer. Used to create a random seed. + """ + + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + + def __init__( + self, + factor, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self._set_value_range(value_range) + self.seed = seed + self.generator = SeedGenerator(seed) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4.
Received: " + f"inputs.shape={images_shape}" + ) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + factor = self.backend.random.uniform( + (batch_size, 1, 1, 1), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + factor = factor + return {"factor": factor} + + def transform_images(self, images, transformation=None, training=True): + if training: + images = self.backend.cast(images, self.compute_dtype) + factor = self.backend.cast( + transformation["factor"], self.compute_dtype + ) + degenerates = self.backend.image.rgb_to_grayscale( + images, data_format=self.data_format + ) + images = images + factor * (degenerates - images) + images = self.backend.numpy.clip( + images, self.value_range[0], self.value_range[1] + ) + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def get_config(self): + config = super().get_config() + config.update( + { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py new file mode 100644 index 00000000000..18a0adc7c1f --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py @@ -0,0 +1,77 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomColorDegenerationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomColorDegeneration, + init_kwargs={ + "factor": 0.75, + "value_range": (0, 1), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_color_degeneration_value_range(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomColorDegeneration(0.2, value_range=(0, 1)) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1)) + + def test_random_color_degeneration_no_op(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((2, 8, 8, 3)) + else: + inputs = np.random.random((2, 3, 8, 8)) + + layer = layers.RandomColorDegeneration((0.5, 0.5)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5) + + def test_random_color_degeneration_factor_zero(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((2, 8, 8, 3)) + else: + inputs = np.random.random((2, 3, 8, 8)) + layer = layers.RandomColorDegeneration(factor=(0.0, 0.0)) + result = layer(inputs) + + self.assertAllClose(inputs, result, atol=1e-3, rtol=1e-5) + + def test_random_color_degeneration_randomness(self): + image = keras.random.uniform(shape=(3, 3, 
3), minval=0, maxval=1) + + layer = layers.RandomColorDegeneration(0.2) + adjusted_images = layer(image) + + self.assertNotAllClose(adjusted_images, image) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomColorDegeneration( + factor=0.5, data_format=data_format, seed=1337 + ) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py new file mode 100644 index 00000000000..eee6f31b8e4 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py @@ -0,0 +1,197 @@ +import keras.src.layers.preprocessing.image_preprocessing.random_brightness as random_brightness # noqa: E501 +import keras.src.layers.preprocessing.image_preprocessing.random_contrast as random_contrast # noqa: E501 +import keras.src.layers.preprocessing.image_preprocessing.random_hue as random_hue # noqa: E501 +import keras.src.layers.preprocessing.image_preprocessing.random_saturation as random_saturation # noqa: E501 +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random.seed_generator import SeedGenerator +from keras.src.utils import backend_utils + + +@keras_export("keras.layers.RandomColorJitter") +class RandomColorJitter(BaseImagePreprocessingLayer): + """RandomColorJitter class randomly applies brightness, contrast, + saturation and hue image processing operations, in sequence, to the + input. + + Args: + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. + This is typically either `[0, 1]` or `[0, 255]` depending + on how your preprocessing pipeline is set up. + brightness_factor: Float or a list/tuple of 2 floats between -1.0 + and 1.0. The factor is used to determine the lower bound and + upper bound of the brightness adjustment. A float value will + be chosen randomly between the limits. When -1.0 is chosen, + the output image will be black, and when 1.0 is chosen, the + image will be fully white. When only one float is provided, + e.g. 0.2, then -0.2 will be used for the lower bound and 0.2 + will be used for the upper bound. + contrast_factor: a positive float represented as a fraction of value, + or a tuple of size 2 representing lower and upper bound. When + represented as a single float, lower = upper. The contrast + factor will be randomly picked between `[1.0 - lower, 1.0 + + upper]`. For any pixel x in the channel, the output will be + `(x - mean) * factor + mean` where `mean` is the mean value + of the channel. + saturation_factor: A tuple of two floats or a single float. `factor` + controls the extent to which the image saturation is impacted. + `factor=0.5` makes this layer perform a no-op operation. + `factor=0.0` makes the image fully grayscale. `factor=1.0` + makes the image fully saturated. Values should be between + `0.0` and `1.0`. If a tuple is used, a `factor` is sampled + between the two values for every image augmented. If a single + float is used, a value between `0.0` and the passed float is + sampled.
To ensure the value is always the same, pass a tuple + with two identical floats: `(0.5, 0.5)`. + hue_factor: A single float or a tuple of two floats. `factor` + controls the extent to which the image hue is impacted. + `factor=0.0` makes this layer perform a no-op operation, + while a value of `1.0` performs the most aggressive hue + adjustment available. If a tuple is used, a `factor` is + sampled between the two values for every image augmented. + If a single float is used, a value between `0.0` and the + passed float is sampled. In order to ensure the value is + always the same, please pass a tuple with two identical + floats: `(0.5, 0.5)`. + seed: Integer. Used to create a random seed. + """ + + def __init__( + self, + value_range=(0, 255), + brightness_factor=None, + contrast_factor=None, + saturation_factor=None, + hue_factor=None, + seed=None, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self.value_range = value_range + self.brightness_factor = brightness_factor + self.contrast_factor = contrast_factor + self.saturation_factor = saturation_factor + self.hue_factor = hue_factor + self.seed = seed + self.generator = SeedGenerator(seed) + + self.random_brightness = None + self.random_contrast = None + self.random_saturation = None + self.random_hue = None + + if self.brightness_factor is not None: + self.random_brightness = random_brightness.RandomBrightness( + factor=self.brightness_factor, + value_range=self.value_range, + seed=self.seed, + ) + + if self.contrast_factor is not None: + self.random_contrast = random_contrast.RandomContrast( + factor=self.contrast_factor, + value_range=self.value_range, + seed=self.seed, + ) + + if self.saturation_factor is not None: + self.random_saturation = random_saturation.RandomSaturation( + factor=self.saturation_factor, + value_range=self.value_range, + seed=self.seed, + ) + + if self.hue_factor is not None: + self.random_hue = random_hue.RandomHue( + factor=self.hue_factor, + value_range=self.value_range, + seed=self.seed, + ) + + def transform_images(self, images, transformation, training=True): + if training: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + images = self.backend.cast(images, self.compute_dtype) + if self.brightness_factor is not None: + if backend_utils.in_tf_graph(): + self.random_brightness.backend.set_backend("tensorflow") + transformation = ( + self.random_brightness.get_random_transformation( + images, + seed=self._get_seed_generator(self.backend._backend), + ) + ) + images = self.random_brightness.transform_images( + images, transformation + ) + if self.contrast_factor is not None: + if backend_utils.in_tf_graph(): + self.random_contrast.backend.set_backend("tensorflow") + transformation = self.random_contrast.get_random_transformation( + images, seed=self._get_seed_generator(self.backend._backend) + ) + transformation["contrast_factor"] = self.backend.cast( + transformation["contrast_factor"], dtype=self.compute_dtype + ) + images = self.random_contrast.transform_images( + images, transformation + ) + if self.saturation_factor is not None: + if backend_utils.in_tf_graph(): + self.random_saturation.backend.set_backend("tensorflow") + transformation = ( + self.random_saturation.get_random_transformation( + images, + seed=self._get_seed_generator(self.backend._backend), + ) + ) + images = self.random_saturation.transform_images( + images, transformation + ) + if self.hue_factor is not None: + if backend_utils.in_tf_graph():
self.random_hue.backend.set_backend("tensorflow") + transformation = self.random_hue.get_random_transformation( + images, seed=self._get_seed_generator(self.backend._backend) + ) + images = self.random_hue.transform_images( + images, transformation + ) + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "value_range": self.value_range, + "brightness_factor": self.brightness_factor, + "contrast_factor": self.contrast_factor, + "saturation_factor": self.saturation_factor, + "hue_factor": self.hue_factor, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py new file mode 100644 index 00000000000..a465970b6b4 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py @@ -0,0 +1,135 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomColorJitterTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomColorJitter, + init_kwargs={ + "value_range": (20, 200), + "brightness_factor": 0.2, + "contrast_factor": 0.2, + "saturation_factor": 0.2, + "hue_factor": 0.2, + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_color_jitter_inference(self): + seed = 3481 + layer = layers.RandomColorJitter( + value_range=(0, 1), + brightness_factor=0.1, + contrast_factor=0.2, + saturation_factor=0.9, + hue_factor=0.1, + ) + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_brightness_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter( + brightness_factor=[0.5, 0.5], seed=seed + ) + output = backend.convert_to_numpy(layer(inputs)) + + layer = layers.RandomBrightness(factor=[0.5, 0.5], seed=seed) + sub_output = backend.convert_to_numpy(layer(inputs)) + + self.assertAllClose(output, sub_output) + + def test_saturation_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter( + saturation_factor=[0.5, 0.5], seed=seed + ) + output = layer(inputs) + + layer = layers.RandomSaturation(factor=[0.5, 0.5], seed=seed) + sub_output = layer(inputs) + + self.assertAllClose(output, sub_output) + + def test_hue_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = 
backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter(hue_factor=[0.5, 0.5], seed=seed) + output = layer(inputs) + + layer = layers.RandomHue(factor=[0.5, 0.5], seed=seed) + sub_output = layer(inputs) + + self.assertAllClose(output, sub_output) + + def test_contrast_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter(contrast_factor=[0.5, 0.5], seed=seed) + output = layer(inputs) + + layer = layers.RandomContrast(factor=[0.5, 0.5], seed=seed) + sub_output = layer(inputs) + + self.assertAllClose(output, sub_output) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomColorJitter( + value_range=(0, 1), + brightness_factor=0.1, + contrast_factor=0.2, + saturation_factor=0.9, + hue_factor=0.1, + ) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_contrast.py b/keras/src/layers/preprocessing/image_preprocessing/random_contrast.py index c9525cb651f..ab793b266e0 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_contrast.py @@ -40,14 +40,19 @@ class RandomContrast(BaseImagePreprocessingLayer): `[1.0 - lower, 1.0 + upper]`. For any pixel x in the channel, the output will be `(x - mean) * factor + mean` where `mean` is the mean value of the channel. + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. This is + typically either `[0, 1]` or `[0, 255]` depending on how your + preprocessing pipeline is set up. seed: Integer. Used to create a random seed. 
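+
+ Example (a sketch: with `value_range=(0, 1)` the adjusted output is
+ clipped to `[0, 1]` instead of the default `[0, 255]`):
+
+ ```python
+ layer = RandomContrast(factor=0.5, value_range=(0, 1))
+ outputs = layer(images)  # contrast-adjusted, then clipped to [0, 1]
+ ```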
""" _FACTOR_BOUNDS = (0, 1) - def __init__(self, factor, seed=None, **kwargs): + def __init__(self, factor, value_range=(0, 255), seed=None, **kwargs): super().__init__(**kwargs) self._set_factor(factor) + self.value_range = value_range self.seed = seed self.generator = SeedGenerator(seed) @@ -89,7 +94,9 @@ def transform_images(self, images, transformation, training=True): if training: constrast_factor = transformation["contrast_factor"] outputs = self._adjust_constrast(images, constrast_factor) - outputs = self.backend.numpy.clip(outputs, 0, 255) + outputs = self.backend.numpy.clip( + outputs, self.value_range[0], self.value_range[1] + ) self.backend.numpy.reshape(outputs, self.backend.shape(images)) return outputs return images @@ -135,6 +142,7 @@ def compute_output_shape(self, input_shape): def get_config(self): config = { "factor": self.factor, + "value_range": self.value_range, "seed": self.seed, } base_config = super().get_config() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py index 8972d88f33e..a0f9cc24cf5 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py @@ -14,6 +14,7 @@ def test_layer(self): layers.RandomContrast, init_kwargs={ "factor": 0.75, + "value_range": (0, 255), "seed": 1, }, input_shape=(8, 3, 4, 3), @@ -24,6 +25,7 @@ def test_layer(self): layers.RandomContrast, init_kwargs={ "factor": 0.75, + "value_range": (0, 255), "seed": 1, "data_format": "channels_first", }, @@ -32,21 +34,67 @@ def test_layer(self): expected_output_shape=(8, 3, 4, 4), ) - def test_random_contrast(self): + def test_random_contrast_with_value_range_0_to_255(self): seed = 9809 np.random.seed(seed) - inputs = np.random.random((12, 8, 16, 3)) - layer = layers.RandomContrast(factor=0.5, seed=seed) - outputs = layer(inputs) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + height_axis = -3 + width_axis = -2 + else: + inputs = np.random.random((12, 3, 8, 16)) + height_axis = -2 + width_axis = -1 + + inputs = backend.convert_to_tensor(inputs, dtype="float32") + layer = layers.RandomContrast( + factor=0.5, value_range=(0, 255), seed=seed + ) + transformation = layer.get_random_transformation(inputs, training=True) + outputs = layer.transform_images(inputs, transformation, training=True) + + # Actual contrast arithmetic + np.random.seed(seed) + factor = backend.convert_to_numpy(transformation["contrast_factor"]) + inputs = backend.convert_to_numpy(inputs) + inp_mean = np.mean(inputs, axis=height_axis, keepdims=True) + inp_mean = np.mean(inp_mean, axis=width_axis, keepdims=True) + actual_outputs = (inputs - inp_mean) * factor + inp_mean + outputs = backend.convert_to_numpy(outputs) + actual_outputs = np.clip(actual_outputs, 0, 255) + + self.assertAllClose(outputs, actual_outputs) + + def test_random_contrast_with_value_range_0_to_1(self): + seed = 9809 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + height_axis = -3 + width_axis = -2 + else: + inputs = np.random.random((12, 3, 8, 16)) + height_axis = -2 + width_axis = -1 + + inputs = backend.convert_to_tensor(inputs, dtype="float32") + layer = layers.RandomContrast(factor=0.5, value_range=(0, 1), seed=seed) + transformation = 
layer.get_random_transformation(inputs, training=True) + outputs = layer.transform_images(inputs, transformation, training=True) # Actual contrast arithmetic np.random.seed(seed) - factor = np.random.uniform(0.5, 1.5) - inp_mean = np.mean(inputs, axis=-3, keepdims=True) - inp_mean = np.mean(inp_mean, axis=-2, keepdims=True) + factor = backend.convert_to_numpy(transformation["contrast_factor"]) + inputs = backend.convert_to_numpy(inputs) + inp_mean = np.mean(inputs, axis=height_axis, keepdims=True) + inp_mean = np.mean(inp_mean, axis=width_axis, keepdims=True) actual_outputs = (inputs - inp_mean) * factor + inp_mean outputs = backend.convert_to_numpy(outputs) - actual_outputs = np.clip(outputs, 0, 255) + actual_outputs = np.clip(actual_outputs, 0, 1) self.assertAllClose(outputs, actual_outputs) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_crop.py b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py index 62571e69a93..f67469089f9 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_crop.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py @@ -122,59 +122,60 @@ def get_random_transformation(self, data, training=True, seed=None): return h_start, w_start def transform_images(self, images, transformation, training=True): - images = self.backend.cast(images, self.compute_dtype) - crop_box_hstart, crop_box_wstart = transformation - crop_height = self.height - crop_width = self.width + if training: + images = self.backend.cast(images, self.compute_dtype) + crop_box_hstart, crop_box_wstart = transformation + crop_height = self.height + crop_width = self.width - if self.data_format == "channels_last": - if len(images.shape) == 4: - images = images[ - :, - crop_box_hstart : crop_box_hstart + crop_height, - crop_box_wstart : crop_box_wstart + crop_width, - :, - ] - else: - images = images[ - crop_box_hstart : crop_box_hstart + crop_height, - crop_box_wstart : crop_box_wstart + crop_width, - :, - ] - else: - if len(images.shape) == 4: - images = images[ - :, - :, - crop_box_hstart : crop_box_hstart + crop_height, - crop_box_wstart : crop_box_wstart + crop_width, - ] + if self.data_format == "channels_last": + if len(images.shape) == 4: + images = images[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + images = images[ + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] else: - images = images[ - :, - crop_box_hstart : crop_box_hstart + crop_height, - crop_box_wstart : crop_box_wstart + crop_width, - ] + if len(images.shape) == 4: + images = images[ + :, + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + else: + images = images[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] - shape = self.backend.shape(images) - new_height = shape[self.height_axis] - new_width = shape[self.width_axis] - if ( - not isinstance(new_height, int) - or not isinstance(new_width, int) - or new_height != self.height - or new_width != self.width - ): - # Resize images if size mismatch or - # if size mismatch cannot be determined - # (in the case of a TF dynamic shape). 
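The reference arithmetic these tests check is worth restating: contrast adjustment computes `(x - mean) * factor + mean`, with the mean taken per image over the height and width axes, then clips to `value_range`. A standalone NumPy sketch of that formula (shapes and values are illustrative):

```python
import numpy as np

def adjust_contrast(images, factor, value_range=(0.0, 1.0)):
    # images: (batch, height, width, channels); scalar factor broadcasts.
    mean = images.mean(axis=(1, 2), keepdims=True)  # per-image, per-channel mean
    out = (images - mean) * factor + mean
    return np.clip(out, value_range[0], value_range[1])

images = np.random.random((2, 8, 8, 3)).astype("float32")
low_contrast = adjust_contrast(images, factor=0.5)
identity = adjust_contrast(images, factor=1.0)
assert np.allclose(identity, images)  # factor=1.0 is a no-op
```

This is the same computation the layer performs through its backend ops, minus the per-image random sampling of `factor`.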
- images = self.backend.image.resize( - images, - size=(self.height, self.width), - data_format=self.data_format, - ) - # Resize may have upcasted the outputs - images = self.backend.cast(images, self.compute_dtype) + shape = self.backend.shape(images) + new_height = shape[self.height_axis] + new_width = shape[self.width_axis] + if ( + not isinstance(new_height, int) + or not isinstance(new_width, int) + or new_height != self.height + or new_width != self.width + ): + # Resize images if size mismatch or + # if size mismatch cannot be determined + # (in the case of a TF dynamic shape). + images = self.backend.image.resize( + images, + size=(self.height, self.width), + data_format=self.data_format, + ) + # Resize may have upcasted the outputs + images = self.backend.cast(images, self.compute_dtype) return images def transform_labels(self, labels, transformation, training=True): @@ -197,56 +198,59 @@ def transform_bounding_boxes( "labels": (num_boxes, num_classes), } """ - h_start, w_start = transformation - if not self.backend.is_tensor(bounding_boxes["boxes"]): - bounding_boxes = densify_bounding_boxes( - bounding_boxes, backend=self.backend - ) - boxes = bounding_boxes["boxes"] - # Convert to a standard xyxy as operations are done xyxy by default. - boxes = convert_format( - boxes=boxes, - source=self.bounding_box_format, - target="xyxy", - height=self.height, - width=self.width, - ) - h_start = self.backend.cast(h_start, boxes.dtype) - w_start = self.backend.cast(w_start, boxes.dtype) - if len(self.backend.shape(boxes)) == 3: - boxes = self.backend.numpy.stack( - [ - self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0), - self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0), - self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0), - self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0), - ], - axis=-1, - ) - else: - boxes = self.backend.numpy.stack( - [ - self.backend.numpy.maximum(boxes[:, 0] - h_start, 0), - self.backend.numpy.maximum(boxes[:, 1] - w_start, 0), - self.backend.numpy.maximum(boxes[:, 2] - h_start, 0), - self.backend.numpy.maximum(boxes[:, 3] - w_start, 0), - ], - axis=-1, + + if training: + h_start, w_start = transformation + if not self.backend.is_tensor(bounding_boxes["boxes"]): + bounding_boxes = densify_bounding_boxes( + bounding_boxes, backend=self.backend + ) + boxes = bounding_boxes["boxes"] + # Convert to a standard xyxy as operations are done xyxy by default. 
+ boxes = convert_format( + boxes=boxes, + source=self.bounding_box_format, + target="xyxy", + height=self.height, + width=self.width, ) + h_start = self.backend.cast(h_start, boxes.dtype) + w_start = self.backend.cast(w_start, boxes.dtype) + if len(self.backend.shape(boxes)) == 3: + boxes = self.backend.numpy.stack( + [ + self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0), + self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0), + self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0), + self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0), + ], + axis=-1, + ) + else: + boxes = self.backend.numpy.stack( + [ + self.backend.numpy.maximum(boxes[:, 0] - h_start, 0), + self.backend.numpy.maximum(boxes[:, 1] - w_start, 0), + self.backend.numpy.maximum(boxes[:, 2] - h_start, 0), + self.backend.numpy.maximum(boxes[:, 3] - w_start, 0), + ], + axis=-1, + ) - # Convert to user defined bounding box format - boxes = convert_format( - boxes=boxes, - source="xyxy", - target=self.bounding_box_format, - height=self.height, - width=self.width, - ) + # Convert to user defined bounding box format + boxes = convert_format( + boxes=boxes, + source="xyxy", + target=self.bounding_box_format, + height=self.height, + width=self.width, + ) - return { - "boxes": boxes, - "labels": bounding_boxes["labels"], - } + return { + "boxes": boxes, + "labels": bounding_boxes["labels"], + } + return bounding_boxes def transform_segmentation_masks( self, segmentation_masks, transformation, training=True diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_flip.py b/keras/src/layers/preprocessing/image_preprocessing/random_flip.py index 519379685d1..83deff5fc05 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_flip.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_flip.py @@ -101,9 +101,6 @@ def transform_bounding_boxes( transformation, training=True, ): - if backend_utils.in_tf_graph(): - self.backend.set_backend("tensorflow") - def _flip_boxes_horizontal(boxes): x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1) outputs = self.backend.numpy.concatenate( @@ -134,46 +131,50 @@ def _transform_xyxy(boxes, box_flips): ) return bboxes - flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1) + if training: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") - if self.data_format == "channels_first": - height_axis = -2 - width_axis = -1 - else: - height_axis = -3 - width_axis = -2 + flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1) - input_height, input_width = ( - transformation["input_shape"][height_axis], - transformation["input_shape"][width_axis], - ) + if self.data_format == "channels_first": + height_axis = -2 + width_axis = -1 + else: + height_axis = -3 + width_axis = -2 - bounding_boxes = convert_format( - bounding_boxes, - source=self.bounding_box_format, - target="rel_xyxy", - height=input_height, - width=input_width, - ) + input_height, input_width = ( + transformation["input_shape"][height_axis], + transformation["input_shape"][width_axis], + ) - bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips) + bounding_boxes = convert_format( + bounding_boxes, + source=self.bounding_box_format, + target="rel_xyxy", + height=input_height, + width=input_width, + ) - bounding_boxes = clip_to_image_size( - bounding_boxes=bounding_boxes, - height=input_height, - width=input_width, - bounding_box_format="xyxy", - ) + bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips) - 
bounding_boxes = convert_format( - bounding_boxes, - source="rel_xyxy", - target=self.bounding_box_format, - height=input_height, - width=input_width, - ) + bounding_boxes = clip_to_image_size( + bounding_boxes=bounding_boxes, + height=input_height, + width=input_width, + bounding_box_format="xyxy", + ) + + bounding_boxes = convert_format( + bounding_boxes, + source="rel_xyxy", + target=self.bounding_box_format, + height=input_height, + width=input_width, + ) - self.backend.reset() + self.backend.reset() return bounding_boxes diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py index 804e9323a0f..e03a626852e 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py @@ -71,17 +71,21 @@ def get_random_transformation(self, images, training=True, seed=None): ) return should_apply - def transform_images(self, images, transformations=None, **kwargs): - should_apply = ( - transformations - if transformations is not None - else self.get_random_transformation(images) - ) + def transform_images(self, images, transformation, training=True): + if training: + should_apply = ( + transformation + if transformation is not None + else self.get_random_transformation(images) + ) - grayscale_images = self.backend.image.rgb_to_grayscale( - images, data_format=self.data_format - ) - return self.backend.numpy.where(should_apply, grayscale_images, images) + grayscale_images = self.backend.image.rgb_to_grayscale( + images, data_format=self.data_format + ) + return self.backend.numpy.where( + should_apply, grayscale_images, images + ) + return images def compute_output_shape(self, input_shape): return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_hue.py b/keras/src/layers/preprocessing/image_preprocessing/random_hue.py index d439beb905d..43ee63ad62b 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_hue.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_hue.py @@ -44,7 +44,12 @@ class RandomHue(BaseImagePreprocessingLayer): _FACTOR_BOUNDS = (0, 1) def __init__( - self, factor, value_range, data_format=None, seed=None, **kwargs + self, + factor, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, ): super().__init__(data_format=data_format, **kwargs) self._set_factor(factor) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py index cbfb355ba35..f115612309d 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py @@ -31,15 +31,24 @@ def test_random_hue_inference(self): output = layer(inputs, training=False) self.assertAllClose(inputs, output) - def test_random_hue_value_range(self): + def test_random_hue_value_range_0_to_1(self): image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) - layer = layers.RandomHue(0.2, (0, 255)) + layer = layers.RandomHue(0.2, (0, 1)) adjusted_image = layer(image) self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1)) + def test_random_hue_value_range_0_to_255(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=255) + + layer = layers.RandomHue(0.2, (0, 255)) + adjusted_image = layer(image) + + 
self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 255)) + def test_random_hue_no_change_with_zero_factor(self): data_format = backend.config.image_data_format() if data_format == "channels_last": diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_posterization.py b/keras/src/layers/preprocessing/image_preprocessing/random_posterization.py new file mode 100644 index 00000000000..55e7536724f --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_posterization.py @@ -0,0 +1,151 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) + + +@keras_export("keras.layers.RandomPosterization") +class RandomPosterization(BaseImagePreprocessingLayer): + """Reduces the number of bits for each color channel. + + References: + - [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501) + - [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719) + + Args: + value_range: a tuple or a list of two elements. The first value + represents the lower bound for values in passed images, the second + represents the upper bound. Images passed to the layer should have + values within `value_range`. Defaults to `(0, 255)`. + factor: integer, the number of bits to keep for each channel. Must be a + value between 1 and 8. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (1, 8) + _MAX_FACTOR = 8 + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + + def __init__( + self, + factor, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self._set_value_range(value_range) + self.seed = seed + self.generator = self.backend.random.SeedGenerator(seed) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. 
Received: " + f"inputs.shape={images_shape}" + ) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + if self.factor[0] != self.factor[1]: + factor = self.backend.random.randint( + (batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + dtype="uint8", + ) + else: + factor = ( + self.backend.numpy.ones((batch_size,), dtype="uint8") + * self.factor[0] + ) + + shift_factor = self._MAX_FACTOR - factor + return {"shift_factor": shift_factor} + + def transform_images(self, images, transformation=None, training=True): + if training: + shift_factor = transformation["shift_factor"] + + shift_factor = self.backend.numpy.reshape( + shift_factor, self.backend.shape(shift_factor) + (1, 1, 1) + ) + + images = self._transform_value_range( + images, + original_range=self.value_range, + target_range=(0, 255), + dtype=self.compute_dtype, + ) + + images = self.backend.cast(images, "uint8") + images = self.backend.numpy.bitwise_left_shift( + self.backend.numpy.bitwise_right_shift(images, shift_factor), + shift_factor, + ) + images = self.backend.cast(images, self.compute_dtype) + + images = self._transform_value_range( + images, + original_range=(0, 255), + target_range=self.value_range, + dtype=self.compute_dtype, + ) + + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def get_config(self): + config = super().get_config() + config.update( + { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py new file mode 100644 index 00000000000..347f82a3a96 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py @@ -0,0 +1,85 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomPosterizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomPosterization, + init_kwargs={ + "factor": 1, + "value_range": (20, 200), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_posterization_inference(self): + seed = 3481 + layer = layers.RandomPosterization(1, [0, 255]) + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_random_posterization_basic(self): + seed = 3481 + layer = layers.RandomPosterization( + 1, [0, 255], data_format="channels_last", seed=seed + ) + np.random.seed(seed) + inputs = np.asarray( + [[[128.0, 235.0, 87.0], [12.0, 1.0, 23.0], [24.0, 18.0, 121.0]]] + ) + output = layer(inputs) + expected_output = np.asarray( + [[[128.0, 128.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]] + ) + self.assertAllClose(expected_output, output) + + def test_random_posterization_value_range_0_to_1(self): + image = 
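The transform above is plain bit masking once values are in `[0, 255]`: keeping `bits` bits of an 8-bit channel means shifting right and then left by `8 - bits`. A small NumPy sketch of that arithmetic (the helper name is hypothetical):

```python
import numpy as np

def posterize(channel_values, bits):
    # Keep the top `bits` bits of each 8-bit value; the rest are zeroed.
    shift = 8 - bits
    return (channel_values.astype(np.uint8) >> shift) << shift

x = np.array([128, 235, 87, 12], dtype=np.uint8)
print(posterize(x, bits=1))  # [128 128   0   0]
```

With `bits=1` only the top bit survives, which matches the `expected_output` asserted by the basic test above.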
keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomPosterization(1, [0, 1.0]) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1)) + + def test_random_posterization_value_range_0_to_255(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=255) + + layer = layers.RandomPosterization(1, [0, 255]) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 255)) + + def test_random_posterization_randomness(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomPosterization(1, [0, 255]) + adjusted_images = layer(image) + + self.assertNotAllClose(adjusted_images, image) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomPosterization(1, [0, 255]) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py b/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py index 70221b9fa69..ea1e4b882fe 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py @@ -131,37 +131,38 @@ def transform_bounding_boxes( transformation, training=True, ): - ops = self.backend - boxes = bounding_boxes["boxes"] - height = transformation["image_height"] - width = transformation["image_width"] - batch_size = transformation["batch_size"] - boxes = converters.affine_transform( - boxes=boxes, - angle=transformation["angle"], - translate_x=ops.numpy.zeros([batch_size]), - translate_y=ops.numpy.zeros([batch_size]), - scale=ops.numpy.ones([batch_size]), - shear_x=ops.numpy.zeros([batch_size]), - shear_y=ops.numpy.zeros([batch_size]), - height=height, - width=width, - ) + if training: + ops = self.backend + boxes = bounding_boxes["boxes"] + height = transformation["image_height"] + width = transformation["image_width"] + batch_size = transformation["batch_size"] + boxes = converters.affine_transform( + boxes=boxes, + angle=transformation["angle"], + translate_x=ops.numpy.zeros([batch_size]), + translate_y=ops.numpy.zeros([batch_size]), + scale=ops.numpy.ones([batch_size]), + shear_x=ops.numpy.zeros([batch_size]), + shear_y=ops.numpy.zeros([batch_size]), + height=height, + width=width, + ) - bounding_boxes["boxes"] = boxes - bounding_boxes = converters.clip_to_image_size( - bounding_boxes, - height=height, - width=width, - bounding_box_format="xyxy", - ) - bounding_boxes = converters.convert_format( - bounding_boxes, - source="xyxy", - target=self.bounding_box_format, - height=height, - width=width, - ) + bounding_boxes["boxes"] = boxes + bounding_boxes = converters.clip_to_image_size( + bounding_boxes, + height=height, + width=width, + bounding_box_format="xyxy", + ) + bounding_boxes = converters.convert_format( + bounding_boxes, + source="xyxy", + target=self.bounding_box_format, + height=height, + width=width, + ) return bounding_boxes def transform_segmentation_masks( diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py 
b/keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py new file mode 100644 index 00000000000..f6f4edb3b81 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py @@ -0,0 +1,168 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator + + +@keras_export("keras.layers.RandomSharpness") +class RandomSharpness(BaseImagePreprocessingLayer): + """Randomly performs the sharpness operation on given images. + + The sharpness operation first performs a blur, then blends between the + original image and the processed image. This operation adjusts the clarity + of the edges in an image, ranging from blurred to enhanced sharpness. + + Args: + factor: A tuple of two floats or a single float. + `factor` controls the extent to which the image sharpness + is impacted. `factor=0.0` results in a fully blurred image, + `factor=0.5` applies no operation (preserving the original image), + and `factor=1.0` enhances the sharpness beyond the original. Values + should be between `0.0` and `1.0`. If a tuple is used, a `factor` + is sampled between the two values for every image augmented. + If a single float is used, a value between `0.0` and the passed + float is sampled. To ensure the value is always the same, + pass a tuple with two identical floats: `(0.5, 0.5)`. + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. This is + typically either `[0, 1]` or `[0, 255]` depending on how your + preprocessing pipeline is set up. + seed: Integer. Used to create a random seed. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + + def __init__( + self, + factor, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self._set_value_range(value_range) + self.seed = seed + self.generator = SeedGenerator(seed) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. 
Received: " + f"inputs.shape={images_shape}" + ) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + factor = self.backend.random.uniform( + (batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + return {"factor": factor} + + def transform_images(self, images, transformation=None, training=True): + images = self.backend.cast(images, self.compute_dtype) + if training: + if self.data_format == "channels_first": + images = self.backend.numpy.swapaxes(images, -3, -1) + + sharpness_factor = self.backend.cast( + transformation["factor"] * 2, dtype=self.compute_dtype + ) + sharpness_factor = self.backend.numpy.reshape( + sharpness_factor, (-1, 1, 1, 1) + ) + + num_channels = self.backend.shape(images)[-1] + + a, b = 1.0 / 13.0, 5.0 / 13.0 + kernel = self.backend.convert_to_tensor( + [[a, a, a], [a, b, a], [a, a, a]], dtype=self.compute_dtype + ) + kernel = self.backend.numpy.reshape(kernel, (3, 3, 1, 1)) + kernel = self.backend.numpy.tile(kernel, [1, 1, num_channels, 1]) + kernel = self.backend.cast(kernel, self.compute_dtype) + + smoothed_image = self.backend.nn.depthwise_conv( + images, + kernel, + strides=1, + padding="same", + data_format="channels_last", + ) + + smoothed_image = self.backend.cast( + smoothed_image, dtype=self.compute_dtype + ) + images = images + (1.0 - sharpness_factor) * ( + smoothed_image - images + ) + + images = self.backend.numpy.clip( + images, self.value_range[0], self.value_range[1] + ) + + if self.data_format == "channels_first": + images = self.backend.numpy.swapaxes(images, -3, -1) + + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def get_config(self): + config = super().get_config() + config.update( + { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_sharpness_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_sharpness_test.py new file mode 100644 index 00000000000..5cf3b10c867 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_sharpness_test.py @@ -0,0 +1,65 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomSharpnessTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomSharpness, + init_kwargs={ + "factor": 0.75, + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_sharpness_value_range(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomSharpness(0.2) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1)) + + def test_random_sharpness_no_op(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((2, 8, 8, 3)) + else: + inputs = np.random.random((2, 3, 
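The blend in `transform_images` maps `factor` onto a blur-to-sharpen axis: with `sharpness_factor = 2 * factor`, the output is `images + (1 - sharpness_factor) * (smoothed - images)`, so `factor=0.0` yields the smoothed image, `factor=0.5` is the identity, and `factor=1.0` extrapolates past the original. A minimal NumPy check of those endpoints (the depthwise blur is replaced by an illustrative stand-in):

```python
import numpy as np

def blend(images, smoothed, factor):
    sharpness_factor = 2.0 * factor
    return images + (1.0 - sharpness_factor) * (smoothed - images)

images = np.random.random((4, 4, 3))
smoothed = np.full_like(images, images.mean())  # stand-in for the depthwise blur

assert np.allclose(blend(images, smoothed, 0.0), smoothed)  # fully blurred
assert np.allclose(blend(images, smoothed, 0.5), images)    # no-op
# factor=1.0 pushes past the original: 2 * images - smoothed
assert np.allclose(blend(images, smoothed, 1.0), 2 * images - smoothed)
```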
8, 8)) + + layer = layers.RandomSharpness((0.5, 0.5)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5) + + def test_random_sharpness_randomness(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5] + + layer = layers.RandomSharpness(0.2) + adjusted_images = layer(image) + + self.assertNotAllClose(adjusted_images, image) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomSharpness( + factor=0.5, data_format=data_format, seed=1337 + ) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_shear.py b/keras/src/layers/preprocessing/image_preprocessing/random_shear.py new file mode 100644 index 00000000000..74390c77c77 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_shear.py @@ -0,0 +1,401 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + clip_to_image_size, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + convert_format, +) +from keras.src.random.seed_generator import SeedGenerator +from keras.src.utils import backend_utils + + +@keras_export("keras.layers.RandomShear") +class RandomShear(BaseImagePreprocessingLayer): + """A preprocessing layer that randomly applies shear transformations to + images. + + This layer shears the input images along the x-axis and/or y-axis by a + randomly selected factor within the specified range. The shear + transformation is applied to each image independently in a batch. Empty + regions created during the transformation are filled according to the + `fill_mode` and `fill_value` parameters. + + Args: + x_factor: A tuple of two floats. For each augmented image, a value + is sampled from the provided range. If a float is passed, the + range is interpreted as `(0, x_factor)`. Values represent a + percentage of the image to shear over. For example, 0.3 shears + pixels up to 30% of the way across the image. All provided values + should be positive. + y_factor: A tuple of two floats. For each augmented image, a value + is sampled from the provided range. If a float is passed, the + range is interpreted as `(0, y_factor)`. Values represent a + percentage of the image to shear over. For example, 0.3 shears + pixels up to 30% of the way up or down the image. All provided + values should be positive. + interpolation: Interpolation mode. Supported values: `"nearest"`, + `"bilinear"`. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"reflect"`. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the + last pixel. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge + with the same constant value `k` specified by `fill_value`. 
+ - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + Note that when using torch backend, `"reflect"` is redirected to + `"mirror"` `(c d c b | a b c d | c b a b)` because torch does + not support `"reflect"`. + Note that torch backend does not support `"wrap"`. + fill_value: A float representing the value to be filled outside the + boundaries when `fill_mode="constant"`. + seed: Integer. Used to create a random seed. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + _FACTOR_VALIDATION_ERROR = ( + "The `factor` argument should be a number (or a list of two numbers) " + "in the range [0, 1.0]. " + ) + _SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest") + _SUPPORTED_INTERPOLATION = ("nearest", "bilinear") + + def __init__( + self, + x_factor=0.0, + y_factor=0.0, + interpolation="bilinear", + fill_mode="reflect", + fill_value=0.0, + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self.x_factor = self._set_factor_with_name(x_factor, "x_factor") + self.y_factor = self._set_factor_with_name(y_factor, "y_factor") + + if fill_mode not in self._SUPPORTED_FILL_MODE: + raise NotImplementedError( + f"Unknown `fill_mode` {fill_mode}. Expected one of " + f"{self._SUPPORTED_FILL_MODE}." + ) + if interpolation not in self._SUPPORTED_INTERPOLATION: + raise NotImplementedError( + f"Unknown `interpolation` {interpolation}. Expected one of " + f"{self._SUPPORTED_INTERPOLATION}." + ) + + self.fill_mode = fill_mode + self.fill_value = fill_value + self.interpolation = interpolation + self.seed = seed + self.generator = SeedGenerator(seed) + self.supports_jit = False + + def _set_factor_with_name(self, factor, factor_name): + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + self._check_factor_range(factor[0]) + self._check_factor_range(factor[1]) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + self._check_factor_range(factor) + factor = abs(factor) + lower, upper = [-factor, factor] + else: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + return lower, upper + + def _check_factor_range(self, input_number): + if input_number > 1.0 or input_number < 0.0: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: input_number={input_number}" + ) + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + + images_shape = self.backend.shape(images) + if len(images_shape) == 3: + batch_size = 1 + else: + batch_size = images_shape[0] + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + invert = self.backend.random.uniform( + minval=0, + maxval=1, + shape=[batch_size, 1], + seed=seed, + dtype=self.compute_dtype, + ) + invert = self.backend.numpy.where( + invert > 0.5, + -self.backend.numpy.ones_like(invert), + self.backend.numpy.ones_like(invert), + ) + + shear_y = self.backend.random.uniform( + minval=self.y_factor[0], + maxval=self.y_factor[1], + shape=[batch_size, 1], + seed=seed, + dtype=self.compute_dtype, + ) + shear_x = self.backend.random.uniform( + minval=self.x_factor[0], + maxval=self.x_factor[1], + shape=[batch_size, 1], 
seed=seed, + dtype=self.compute_dtype, + ) + shear_factor = ( + self.backend.cast( + self.backend.numpy.concatenate([shear_x, shear_y], axis=1), + dtype=self.compute_dtype, + ) + * invert + ) + return {"shear_factor": shear_factor, "input_shape": images_shape} + + def transform_images(self, images, transformation, training=True): + images = self.backend.cast(images, self.compute_dtype) + if training: + return self._shear_inputs(images, transformation) + return images + + def _shear_inputs(self, inputs, transformation): + if transformation is None: + return inputs + + inputs_shape = self.backend.shape(inputs) + unbatched = len(inputs_shape) == 3 + if unbatched: + inputs = self.backend.numpy.expand_dims(inputs, axis=0) + + shear_factor = transformation["shear_factor"] + outputs = self.backend.image.affine_transform( + inputs, + transform=self._get_shear_matrix(shear_factor), + interpolation=self.interpolation, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + data_format=self.data_format, + ) + + if unbatched: + outputs = self.backend.numpy.squeeze(outputs, axis=0) + return outputs + + def _get_shear_matrix(self, shear_factors): + num_shear_factors = self.backend.shape(shear_factors)[0] + + # The shear matrix looks like: + # [[1 s_x 0] + # [s_y 1 0] + # [0 0 1]] + + return self.backend.numpy.stack( + [ + self.backend.numpy.ones((num_shear_factors,)), + shear_factors[:, 0], + self.backend.numpy.zeros((num_shear_factors,)), + shear_factors[:, 1], + self.backend.numpy.ones((num_shear_factors,)), + self.backend.numpy.zeros((num_shear_factors,)), + self.backend.numpy.zeros((num_shear_factors,)), + self.backend.numpy.zeros((num_shear_factors,)), + ], + axis=1, + ) + + def transform_labels(self, labels, transformation, training=True): + return labels + + def get_transformed_x_y(self, x, y, transform): + a0, a1, a2, b0, b1, b2, c0, c1 = self.backend.numpy.split( + transform, 8, axis=-1 + ) + + k = c0 * x + c1 * y + 1 + x_transformed = (a0 * x + a1 * y + a2) / k + y_transformed = (b0 * x + b1 * y + b2) / k + return x_transformed, y_transformed + + def get_shifted_bbox(self, bounding_boxes, w_shift_factor, h_shift_factor): + bboxes = bounding_boxes["boxes"] + x1, x2, x3, x4 = self.backend.numpy.split(bboxes, 4, axis=-1) + + w_shift_factor = self.backend.convert_to_tensor( + w_shift_factor, dtype=x1.dtype + ) + h_shift_factor = self.backend.convert_to_tensor( + h_shift_factor, dtype=x1.dtype + ) + + if len(bboxes.shape) == 3: + w_shift_factor = self.backend.numpy.expand_dims(w_shift_factor, -1) + h_shift_factor = self.backend.numpy.expand_dims(h_shift_factor, -1) + + bounding_boxes["boxes"] = self.backend.numpy.concatenate( + [ + x1 - w_shift_factor, + x2 - h_shift_factor, + x3 - w_shift_factor, + x4 - h_shift_factor, + ], + axis=-1, + ) + return bounding_boxes + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + def _get_height_width(transformation): + if self.data_format == "channels_first": + height_axis = -2 + width_axis = -1 + else: + height_axis = -3 + width_axis = -2 + input_height, input_width = ( + transformation["input_shape"][height_axis], + transformation["input_shape"][width_axis], + ) + return input_height, input_width + + if training: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + input_height, input_width = _get_height_width(transformation) + + bounding_boxes = convert_format( + bounding_boxes, + source=self.bounding_box_format, + target="rel_xyxy", + height=input_height, + width=input_width, + 
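`affine_transform` consumes the matrix flattened to eight parameters `(a0, a1, a2, b0, b1, b2, c0, c1)`, mapping coordinates via `x' = (a0*x + a1*y + a2) / k` and `y' = (b0*x + b1*y + b2) / k` with `k = c0*x + c1*y + 1`, exactly the arithmetic `get_transformed_x_y` above implements. For the pure shear built here that reduces to `x' = x + s_x * y` and `y' = s_y * x + y`. A NumPy sketch of the mapping (point and factors are hypothetical):

```python
import numpy as np

def apply_affine(x, y, transform):
    # transform: the 8-parameter flattened affine matrix convention.
    a0, a1, a2, b0, b1, b2, c0, c1 = transform
    k = c0 * x + c1 * y + 1.0
    return (a0 * x + a1 * y + a2) / k, (b0 * x + b1 * y + b2) / k

s_x, s_y = 0.3, 0.0  # shear 30% along x only
shear = np.array([1.0, s_x, 0.0, s_y, 1.0, 0.0, 0.0, 0.0])
print(apply_affine(4.0, 10.0, shear))  # (7.0, 10.0): x shifted by s_x * y
```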
dtype=self.compute_dtype, + ) + + bounding_boxes = self._shear_bboxes(bounding_boxes, transformation) + + bounding_boxes = clip_to_image_size( + bounding_boxes=bounding_boxes, + height=input_height, + width=input_width, + bounding_box_format="rel_xyxy", + ) + + bounding_boxes = convert_format( + bounding_boxes, + source="rel_xyxy", + target=self.bounding_box_format, + height=input_height, + width=input_width, + dtype=self.compute_dtype, + ) + + self.backend.reset() + + return bounding_boxes + + def _shear_bboxes(self, bounding_boxes, transformation): + shear_factor = self.backend.cast( + transformation["shear_factor"], dtype=self.compute_dtype + ) + shear_x_amount, shear_y_amount = self.backend.numpy.split( + shear_factor, 2, axis=-1 + ) + + x1, y1, x2, y2 = self.backend.numpy.split( + bounding_boxes["boxes"], 4, axis=-1 + ) + x1 = self.backend.numpy.squeeze(x1, axis=-1) + y1 = self.backend.numpy.squeeze(y1, axis=-1) + x2 = self.backend.numpy.squeeze(x2, axis=-1) + y2 = self.backend.numpy.squeeze(y2, axis=-1) + + if shear_x_amount is not None: + x1_top = x1 - (shear_x_amount * y1) + x1_bottom = x1 - (shear_x_amount * y2) + x1 = self.backend.numpy.where(shear_x_amount < 0, x1_top, x1_bottom) + + x2_top = x2 - (shear_x_amount * y1) + x2_bottom = x2 - (shear_x_amount * y2) + x2 = self.backend.numpy.where(shear_x_amount < 0, x2_bottom, x2_top) + + if shear_y_amount is not None: + y1_left = y1 - (shear_y_amount * x1) + y1_right = y1 - (shear_y_amount * x2) + y1 = self.backend.numpy.where(shear_y_amount > 0, y1_right, y1_left) + + y2_left = y2 - (shear_y_amount * x1) + y2_right = y2 - (shear_y_amount * x2) + y2 = self.backend.numpy.where(shear_y_amount > 0, y2_left, y2_right) + + boxes = self.backend.numpy.concatenate( + [ + self.backend.numpy.expand_dims(x1, axis=-1), + self.backend.numpy.expand_dims(y1, axis=-1), + self.backend.numpy.expand_dims(x2, axis=-1), + self.backend.numpy.expand_dims(y2, axis=-1), + ], + axis=-1, + ) + bounding_boxes["boxes"] = boxes + + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def get_config(self): + base_config = super().get_config() + config = { + "x_factor": self.x_factor, + "y_factor": self.y_factor, + "fill_mode": self.fill_mode, + "interpolation": self.interpolation, + "seed": self.seed, + "fill_value": self.fill_value, + "data_format": self.data_format, + } + return {**base_config, **config} + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py new file mode 100644 index 00000000000..9d5592ff491 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py @@ -0,0 +1,200 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.src.utils import backend_utils + + +class RandomShearTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomShear, + init_kwargs={ + "x_factor": (0.5, 1), + "y_factor": (0.5, 1), + "interpolation": "bilinear", + "fill_mode": "reflect", + "data_format": "channels_last", + "seed": 1, + }, + input_shape=(8, 3, 4, 3), 
+ supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_shear_inference(self): + seed = 3481 + layer = layers.RandomShear(1, 1) + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_shear_pixel_level(self): + image = np.zeros((1, 5, 5, 3)) + image[0, 1:4, 1:4, :] = 1.0 + image[0, 2, 2, :] = [0.0, 1.0, 0.0] + image = keras.ops.convert_to_tensor(image, dtype="float32") + + data_format = backend.config.image_data_format() + if data_format == "channels_first": + image = keras.ops.transpose(image, (0, 3, 1, 2)) + + shear_layer = layers.RandomShear( + x_factor=(0.2, 0.3), + y_factor=(0.2, 0.3), + interpolation="bilinear", + fill_mode="constant", + fill_value=0.0, + seed=42, + data_format=data_format, + ) + + sheared_image = shear_layer(image) + + if data_format == "channels_first": + sheared_image = keras.ops.transpose(sheared_image, (0, 2, 3, 1)) + + original_pixel = image[0, 2, 2, :] + sheared_pixel = sheared_image[0, 2, 2, :] + self.assertNotAllClose(original_pixel, sheared_pixel) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomShear(1, 1) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() + + @parameterized.named_parameters( + ( + "with_x_shift", + [[1.0, 0.0]], + [[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]], + ), + ( + "with_y_shift", + [[0.0, 1.0]], + [[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]], + ), + ( + "with_xy_shift", + [[1.0, 1.0]], + [[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]], + ), + ) + def test_random_shear_bounding_boxes(self, translation, expected_boxes): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (10, 8, 3) + else: + image_shape = (3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ), + "labels": np.array([[1, 2]]), + } + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + layer = layers.RandomShear( + x_factor=0.5, + y_factor=0.5, + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "shear_factor": backend_utils.convert_tf_tensor( + np.array(translation) + ), + "input_shape": image_shape, + } + output = layer.transform_bounding_boxes( + input_data["bounding_boxes"], + transformation=transformation, + training=True, + ) + + self.assertAllClose(output["boxes"], expected_boxes) + + @parameterized.named_parameters( + ( + "with_x_shift", + [[1.0, 0.0]], + [[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]], + ), + ( + "with_y_shift", + [[0.0, 1.0]], + [[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]], + ), + ( + "with_xy_shift", + [[1.0, 1.0]], + [[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]], + ), + ) + def test_random_shear_tf_data_bounding_boxes( + self, translation, expected_boxes + ): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), + "labels": np.array([[1, 2]]), + } + + 
input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + + ds = tf_data.Dataset.from_tensor_slices(input_data) + layer = layers.RandomShear( + x_factor=0.5, + y_factor=0.5, + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "shear_factor": np.array(translation), + "input_shape": image_shape, + } + + ds = ds.map( + lambda x: layer.transform_bounding_boxes( + x["bounding_boxes"], + transformation=transformation, + training=True, + ) + ) + + output = next(iter(ds)) + expected_boxes = np.array(expected_boxes) + self.assertAllClose(output["boxes"], expected_boxes) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_translation.py b/keras/src/layers/preprocessing/image_preprocessing/random_translation.py index 60e29e0a5b9..1dc69a0db45 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_translation.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_translation.py @@ -215,55 +215,56 @@ def transform_bounding_boxes( transformation, training=True, ): - if backend_utils.in_tf_graph(): - self.backend.set_backend("tensorflow") - - if self.data_format == "channels_first": - height_axis = -2 - width_axis = -1 - else: - height_axis = -3 - width_axis = -2 - - input_height, input_width = ( - transformation["input_shape"][height_axis], - transformation["input_shape"][width_axis], - ) + if training: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + if self.data_format == "channels_first": + height_axis = -2 + width_axis = -1 + else: + height_axis = -3 + width_axis = -2 + + input_height, input_width = ( + transformation["input_shape"][height_axis], + transformation["input_shape"][width_axis], + ) - bounding_boxes = convert_format( - bounding_boxes, - source=self.bounding_box_format, - target="xyxy", - height=input_height, - width=input_width, - ) + bounding_boxes = convert_format( + bounding_boxes, + source=self.bounding_box_format, + target="xyxy", + height=input_height, + width=input_width, + ) - translations = transformation["translations"] - transform = self._get_translation_matrix(translations) + translations = transformation["translations"] + transform = self._get_translation_matrix(translations) - w_shift_factor, h_shift_factor = self.get_transformed_x_y( - 0, 0, transform - ) - bounding_boxes = self.get_shifted_bbox( - bounding_boxes, w_shift_factor, h_shift_factor - ) + w_shift_factor, h_shift_factor = self.get_transformed_x_y( + 0, 0, transform + ) + bounding_boxes = self.get_shifted_bbox( + bounding_boxes, w_shift_factor, h_shift_factor + ) - bounding_boxes = clip_to_image_size( - bounding_boxes=bounding_boxes, - height=input_height, - width=input_width, - bounding_box_format="xyxy", - ) + bounding_boxes = clip_to_image_size( + bounding_boxes=bounding_boxes, + height=input_height, + width=input_width, + bounding_box_format="xyxy", + ) - bounding_boxes = convert_format( - bounding_boxes, - source="xyxy", - target=self.bounding_box_format, - height=input_height, - width=input_width, - ) + bounding_boxes = convert_format( + bounding_boxes, + source="xyxy", + target=self.bounding_box_format, + height=input_height, + width=input_width, + ) - self.backend.reset() + self.backend.reset() return bounding_boxes diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py b/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py index ec0f03d1c2e..80b29b8e6ad 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py +++ 
b/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py @@ -217,84 +217,87 @@ def transform_bounding_boxes( transformation, training=True, ): - if backend_utils.in_tf_graph(): - self.backend.set_backend("tensorflow") - - width_zoom = transformation["width_zoom"] - height_zoom = transformation["height_zoom"] - inputs_shape = transformation["input_shape"] - - if self.data_format == "channels_first": - height = inputs_shape[-2] - width = inputs_shape[-1] - else: - height = inputs_shape[-3] - width = inputs_shape[-2] - - bounding_boxes = convert_format( - bounding_boxes, - source=self.bounding_box_format, - target="xyxy", - height=height, - width=width, - ) + if training: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + width_zoom = transformation["width_zoom"] + height_zoom = transformation["height_zoom"] + inputs_shape = transformation["input_shape"] + + if self.data_format == "channels_first": + height = inputs_shape[-2] + width = inputs_shape[-1] + else: + height = inputs_shape[-3] + width = inputs_shape[-2] + + bounding_boxes = convert_format( + bounding_boxes, + source=self.bounding_box_format, + target="xyxy", + height=height, + width=width, + ) - zooms = self.backend.cast( - self.backend.numpy.concatenate([width_zoom, height_zoom], axis=1), - dtype="float32", - ) - transform = self._get_zoom_matrix(zooms, height, width) + zooms = self.backend.cast( + self.backend.numpy.concatenate( + [width_zoom, height_zoom], axis=1 + ), + dtype="float32", + ) + transform = self._get_zoom_matrix(zooms, height, width) - w_start, h_start = self.get_transformed_x_y( - 0, - 0, - transform, - ) + w_start, h_start = self.get_transformed_x_y( + 0, + 0, + transform, + ) - w_end, h_end = self.get_transformed_x_y( - width, - height, - transform, - ) + w_end, h_end = self.get_transformed_x_y( + width, + height, + transform, + ) - bounding_boxes = self.get_clipped_bbox( - bounding_boxes, h_end, h_start, w_end, w_start - ) + bounding_boxes = self.get_clipped_bbox( + bounding_boxes, h_end, h_start, w_end, w_start + ) - height_transformed = h_end - h_start - width_transformed = w_end - w_start + height_transformed = h_end - h_start + width_transformed = w_end - w_start - height_transformed = self.backend.numpy.expand_dims( - height_transformed, -1 - ) - width_transformed = self.backend.numpy.expand_dims( - width_transformed, -1 - ) + height_transformed = self.backend.numpy.expand_dims( + height_transformed, -1 + ) + width_transformed = self.backend.numpy.expand_dims( + width_transformed, -1 + ) - bounding_boxes = convert_format( - bounding_boxes, - source="xyxy", - target="rel_xyxy", - height=height_transformed, - width=width_transformed, - ) + bounding_boxes = convert_format( + bounding_boxes, + source="xyxy", + target="rel_xyxy", + height=height_transformed, + width=width_transformed, + ) - bounding_boxes = clip_to_image_size( - bounding_boxes=bounding_boxes, - height=height_transformed, - width=width_transformed, - bounding_box_format="rel_xyxy", - ) + bounding_boxes = clip_to_image_size( + bounding_boxes=bounding_boxes, + height=height_transformed, + width=width_transformed, + bounding_box_format="rel_xyxy", + ) - bounding_boxes = convert_format( - bounding_boxes, - source="rel_xyxy", - target=self.bounding_box_format, - height=height, - width=width, - ) + bounding_boxes = convert_format( + bounding_boxes, + source="rel_xyxy", + target=self.bounding_box_format, + height=height, + width=width, + ) - self.backend.reset() + self.backend.reset() return bounding_boxes diff 
--git a/keras/src/layers/preprocessing/image_preprocessing/solarization.py b/keras/src/layers/preprocessing/image_preprocessing/solarization.py index a49d3930f8a..2a8fcee5efa 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/solarization.py +++ b/keras/src/layers/preprocessing/image_preprocessing/solarization.py @@ -156,33 +156,36 @@ def get_random_transformation(self, data, training=True, seed=None): def transform_images(self, images, transformation, training=True): images = self.backend.cast(images, self.compute_dtype) - if transformation is None: - return images - - thresholds = transformation["thresholds"] - additions = transformation["additions"] - images = self._transform_value_range( - images, - original_range=self.value_range, - target_range=(0, 255), - dtype=self.compute_dtype, - ) - results = images + additions - results = self.backend.numpy.clip(results, 0, 255) - results = self.backend.numpy.where( - results < thresholds, results, 255 - results - ) - results = self._transform_value_range( - results, - original_range=(0, 255), - target_range=self.value_range, - dtype=self.compute_dtype, - ) - if results.dtype == images.dtype: - return results - if backend.is_int_dtype(images.dtype): - results = self.backend.numpy.round(results) - return _saturate_cast(results, images.dtype, self.backend) + + if training: + if transformation is None: + return images + + thresholds = transformation["thresholds"] + additions = transformation["additions"] + images = self._transform_value_range( + images, + original_range=self.value_range, + target_range=(0, 255), + dtype=self.compute_dtype, + ) + results = images + additions + results = self.backend.numpy.clip(results, 0, 255) + results = self.backend.numpy.where( + results < thresholds, results, 255 - results + ) + results = self._transform_value_range( + results, + original_range=(0, 255), + target_range=self.value_range, + dtype=self.compute_dtype, + ) + if results.dtype == images.dtype: + return results + if backend.is_int_dtype(images.dtype): + results = self.backend.numpy.round(results) + return _saturate_cast(results, images.dtype, self.backend) + return images def transform_labels(self, labels, transformation, training=True): return labels diff --git a/keras/src/metrics/iou_metrics.py b/keras/src/metrics/iou_metrics.py index a3c9a904b06..65c84e591b9 100644 --- a/keras/src/metrics/iou_metrics.py +++ b/keras/src/metrics/iou_metrics.py @@ -702,7 +702,6 @@ class apply. associated label. axis: (Optional) The dimension containing the logits. Defaults to `-1`. - Example: Example: diff --git a/keras/src/models/model.py b/keras/src/models/model.py index 832e0b35b36..46f10307654 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -470,15 +470,12 @@ def export( ): """Export the model as an artifact for inference. - **Note:** This feature is currently supported only with TensorFlow and - JAX backends. - **Note:** Currently, only `format="tf_saved_model"` is supported. - Args: filepath: `str` or `pathlib.Path` object. The path to save the artifact. - format: `str`. The export format. Supported value: - `"tf_saved_model"`. Defaults to `"tf_saved_model"`. + format: `str`. The export format. Supported values: + `"tf_saved_model"` and `"onnx"`. Defaults to + `"tf_saved_model"`. verbose: `bool`. Whether to print a message during export. Defaults to `True`. input_signature: Optional. Specifies the shape and dtype of the @@ -487,7 +484,7 @@ def export( not provided, it will be automatically computed. Defaults to `None`. 
            **kwargs: Additional keyword arguments:
-                - Specific to the JAX backend:
+                - Specific to the JAX backend and `format="tf_saved_model"`:
                    - `is_static`: Optional `bool`. Indicates whether `fn` is
                        static. Set to `False` if `fn` involves state updates
                        (e.g., RNG seeds and counters).
@@ -498,7 +495,12 @@ def export(
             If `native_serialization` and `polymorphic_shapes` are not
             provided, they will be automatically computed.

-        Example:
+        **Note:** This feature is currently supported only with TensorFlow, JAX
+        and Torch backends.
+
+        Examples:
+
+        Here's how to export a TensorFlow SavedModel for inference.

         ```python
         # Export the model as a TensorFlow SavedModel artifact
@@ -508,10 +510,25 @@ def export(
         reloaded_artifact = tf.saved_model.load("path/to/location")
         predictions = reloaded_artifact.serve(input_data)
         ```
+
+        Here's how to export an ONNX model for inference.
+
+        ```python
+        # Export the model as an ONNX artifact
+        model.export("path/to/location", format="onnx")
+
+        # Load the artifact in a different process/environment
+        ort_session = onnxruntime.InferenceSession("path/to/location")
+        ort_inputs = {
+            k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
+        }
+        predictions = ort_session.run(None, ort_inputs)
+        ```
        """
-        from keras.src.export import export_lib
+        from keras.src.export import export_onnx
+        from keras.src.export import export_saved_model

-        available_formats = ("tf_saved_model",)
+        available_formats = ("tf_saved_model", "onnx")
        if format not in available_formats:
            raise ValueError(
                f"Unrecognized format={format}. Supported formats are: "
@@ -519,7 +536,15 @@ def export(
            )

        if format == "tf_saved_model":
-            export_lib.export_saved_model(
+            export_saved_model(
+                self,
+                filepath,
+                verbose,
+                input_signature=input_signature,
+                **kwargs,
+            )
+        elif format == "onnx":
+            export_onnx(
                self,
                filepath,
                verbose,
diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py
index de7fd98e9db..eb83cad4235 100644
--- a/keras/src/models/model_test.py
+++ b/keras/src/models/model_test.py
@@ -1219,33 +1219,74 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
        )
        self.assertListEqual(hist_keys, ref_keys)

+    @parameterized.named_parameters(
+        ("tf_saved_model", "tf_saved_model"),
+        ("onnx", "onnx"),
+    )
    @pytest.mark.skipif(
-        backend.backend() not in ("tensorflow", "jax"),
+        backend.backend() not in ("tensorflow", "jax", "torch"),
        reason=(
-            "Currently, `Model.export` only supports the tensorflow and jax"
-            " backends."
+            "Currently, `Model.export` only supports the tensorflow, jax and "
+            "torch backends."
        ),
    )
    @pytest.mark.skipif(
        testing.jax_uses_gpu(), reason="Leads to core dumps on CI"
    )
-    def test_export(self):
-        import tensorflow as tf
+    def test_export(self, export_format):
+        if export_format == "tf_saved_model" and testing.torch_uses_gpu():
+            self.skipTest("Leads to core dumps on CI")

        temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
        model = _get_model()
-        x1 = np.random.rand(2, 3)
-        x2 = np.random.rand(2, 3)
+        x1 = np.random.rand(1, 3).astype("float32")
+        x2 = np.random.rand(1, 3).astype("float32")
        ref_output = model([x1, x2])

-        model.export(temp_filepath)
-        revived_model = tf.saved_model.load(temp_filepath)
-        self.assertAllClose(ref_output, revived_model.serve([x1, x2]))
+        model.export(temp_filepath, format=export_format)

-        # Test with a different batch size
-        revived_model.serve(
-            [np.concatenate([x1, x1], axis=0), np.concatenate([x2, x2], axis=0)]
-        )
+        if export_format == "tf_saved_model":
+            import tensorflow as tf
+
+            revived_model = tf.saved_model.load(temp_filepath)
+            self.assertAllClose(ref_output, revived_model.serve([x1, x2]))
+
+            # Test with a different batch size
+            if backend.backend() == "torch":
+                # TODO: Dynamic shape is not supported yet in the torch backend
+                return
+            revived_model.serve(
+                [
+                    np.concatenate([x1, x1], axis=0),
+                    np.concatenate([x2, x2], axis=0),
+                ]
+            )
+        elif export_format == "onnx":
+            import onnxruntime
+
+            ort_session = onnxruntime.InferenceSession(temp_filepath)
+            ort_inputs = {
+                k.name: v for k, v in zip(ort_session.get_inputs(), [x1, x2])
+            }
+            self.assertAllClose(
+                ref_output, ort_session.run(None, ort_inputs)[0]
+            )
+
+            # Test with a different batch size
+            if backend.backend() == "torch":
+                # TODO: Dynamic shape is not supported yet in the torch backend
+                return
+            ort_inputs = {
+                k.name: v
+                for k, v in zip(
+                    ort_session.get_inputs(),
+                    [
+                        np.concatenate([x1, x1], axis=0),
+                        np.concatenate([x2, x2], axis=0),
+                    ],
+                )
+            }
+            ort_session.run(None, ort_inputs)

    def test_export_error(self):
        temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
@@ -1256,9 +1297,12 @@ def test_export_error(self):
            model.export(temp_filepath, format="bad_format")

        # Bad backend
-        if backend.backend() not in ("tensorflow", "jax"):
+        if backend.backend() not in ("tensorflow", "jax", "torch"):
            with self.assertRaisesRegex(
                NotImplementedError,
-                "The export API is only compatible with JAX and TF backends.",
+                (
+                    r"`export_saved_model` only currently supports the "
+                    r"tensorflow, jax and torch backends."
+                ),
            ):
-                model.export(temp_filepath)
+                model.export(temp_filepath, format="tf_saved_model")
diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py
index 8717192fa84..57833afadc7 100644
--- a/keras/src/optimizers/base_optimizer.py
+++ b/keras/src/optimizers/base_optimizer.py
@@ -245,9 +245,29 @@ def add_variable(
        shape,
        initializer="zeros",
        dtype=None,
-        aggregation="mean",
+        aggregation="none",
        name=None,
    ):
+        """Add a variable to the optimizer.
+
+        Args:
+            shape: Shape tuple for the variable. Must be fully-defined
+                (no `None` entries).
+            initializer: Initializer object to use to populate the initial
+                variable value, or string name of a built-in initializer
+                (e.g. `"random_normal"`). Defaults to `"zeros"`.
+            dtype: Dtype of the variable to create, e.g. `"float32"`. If
+                unspecified, defaults to `keras.backend.floatx()`.
+            aggregation: Optional string, one of `None`, `"none"`, `"mean"`,
+                `"sum"` or `"only_first_replica"`. Annotates the variable with
+                the type of multi-replica aggregation to be used for this
+                variable when writing custom data parallel training loops.
+                Defaults to `"none"`.
+            name: String name of the variable. Useful for debugging purposes.
+
+        Returns:
+            An optimizer variable, in the format of `keras.Variable`.
+        """
        self._check_super_called()
        initializer = initializers.get(initializer)
        with backend.name_scope(self.name, caller=self):
@@ -265,8 +285,27 @@ def add_variable(
    def add_variable_from_reference(
        self, reference_variable, name=None, initializer="zeros"
    ):
-        """Add an all-zeros variable with the shape and dtype of a reference
-        variable.
+        """Add an optimizer variable from a model variable.
+
+        Create an optimizer variable based on the information of the model
+        variable. For example, in SGD with momentum, for each model variable,
+        a corresponding momentum variable of the same shape and dtype is
+        created.
+
+        Args:
+            reference_variable: `keras.Variable`. The corresponding model
+                variable to the optimizer variable to be created.
+            name: Optional string. The name prefix of the optimizer variable to
+                be created. If not provided, it will be set to `"var"`. The
+                variable name will follow the pattern
+                `{variable_name}_{reference_variable.name}`,
+                e.g., `momentum/dense_1`. Defaults to `None`.
+            initializer: Initializer object to use to populate the initial
+                variable value, or string name of a built-in initializer
+                (e.g. `"random_normal"`). If unspecified, defaults to
+                `"zeros"`.
+
+        Returns:
+            An optimizer variable, in the format of `keras.Variable`.
        """
        name = name or "var"
        if hasattr(reference_variable, "path"):
diff --git a/keras/src/utils/backend_utils_test.py b/keras/src/utils/backend_utils_test.py
index 6255f0d7bd7..24883104601 100644
--- a/keras/src/utils/backend_utils_test.py
+++ b/keras/src/utils/backend_utils_test.py
@@ -15,7 +15,7 @@ class BackendUtilsTest(testing.TestCase):
    )
    def test_dynamic_backend(self, name):
        dynamic_backend = backend_utils.DynamicBackend()
-        x = np.random.uniform(size=[1, 2, 3])
+        x = np.random.uniform(size=[1, 2, 3]).astype("float32")

        if name == "numpy":
            dynamic_backend.set_backend(name)
diff --git a/keras/src/utils/model_visualization.py b/keras/src/utils/model_visualization.py
index a417a61f0bd..786a72b8b6f 100644
--- a/keras/src/utils/model_visualization.py
+++ b/keras/src/utils/model_visualization.py
@@ -8,16 +8,15 @@ from keras.src.utils import io_utils

 try:
-    # pydot-ng is a fork of pydot that is better maintained.
-    import pydot_ng as pydot
+    import pydot
 except ImportError:
-    # pydotplus is an improved version of pydot
+    # pydot_ng and pydotplus are older forks of pydot
+    # which may still be used by some users
     try:
-        import pydotplus as pydot
+        import pydot_ng as pydot
     except ImportError:
-        # Fall back on pydot if necessary.
         try:
-            import pydot
+            import pydotplus as pydot
         except ImportError:
             pydot = None
diff --git a/keras/src/utils/module_utils.py b/keras/src/utils/module_utils.py
index a0a218a1512..190bc8dc72f 100644
--- a/keras/src/utils/module_utils.py
+++ b/keras/src/utils/module_utils.py
@@ -2,10 +2,13 @@

 class LazyModule:
-    def __init__(self, name, pip_name=None):
+    def __init__(self, name, pip_name=None, import_error_msg=None):
        self.name = name
-        pip_name = pip_name or name
-        self.pip_name = pip_name
+        self.pip_name = pip_name or name
+        self.import_error_msg = import_error_msg or (
+            f"This requires the {self.name} module. "
+            f"You can install it via `pip install {self.pip_name}`"
+        )
        self.module = None
        self._available = None

@@ -23,10 +26,7 @@ def initialize(self):
        try:
            self.module = importlib.import_module(self.name)
        except ImportError:
-            raise ImportError(
-                f"This requires the {self.name} module. "
-                f"You can install it via `pip install {self.pip_name}`"
-            )
+            raise ImportError(self.import_error_msg)

    def __getattr__(self, name):
        if name == "_api_export_path":
@@ -45,5 +45,17 @@ def __repr__(self):
 scipy = LazyModule("scipy")
 jax = LazyModule("jax")
 torchvision = LazyModule("torchvision")
+torch_xla = LazyModule(
+    "torch_xla",
+    import_error_msg=(
+        "This requires the torch_xla module. You can install it via "
+        "`pip install torch-xla`. Additionally, you may need to update "
+        "LD_LIBRARY_PATH if necessary. Torch XLA builds a shared library, "
+        "_XLAC.so, which needs to link to the version of Python it was built "
+        "with. Use the following command to update LD_LIBRARY_PATH: "
+        "`export LD_LIBRARY_PATH=<path to Python>/lib:$LD_LIBRARY_PATH`"
+    ),
+)
 optree = LazyModule("optree")
 dmtree = LazyModule("tree")
+tf2onnx = LazyModule("tf2onnx")
diff --git a/keras/src/version.py b/keras/src/version.py
index 0a3be890297..db523fbaa13 100644
--- a/keras/src/version.py
+++ b/keras/src/version.py
@@ -1,7 +1,7 @@
 from keras.src.api_export import keras_export

 # Unique source of truth for the version number.
-__version__ = "3.7.0"
+__version__ = "3.8.0"

 @keras_export("keras.version")
diff --git a/requirements-common.txt b/requirements-common.txt
index 2d1ec92d911..51c682f9ef4 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -20,3 +20,5 @@ packaging
 # for tree_test.py
 dm_tree
 coverage!=7.6.5 # 7.6.5 breaks CI
+# for onnx_test.py
+onnxruntime
diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt
index 34cde7ba8e3..7b1d2166f63 100644
--- a/requirements-jax-cuda.txt
+++ b/requirements-jax-cuda.txt
@@ -1,10 +1,12 @@
 # Tensorflow cpu-only version (needed for testing).
-tensorflow-cpu~=2.18.0 # Pin to TF 2.16
+tensorflow-cpu~=2.18.0
+tf2onnx

 # Torch cpu-only version (needed for testing).
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch>=2.1.0
 torchvision>=0.16.0
+torch-xla

 # Jax with cuda support.
 # TODO: Higher version breaks CI.
diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt
index ded589258c8..fed601f658f 100644
--- a/requirements-tensorflow-cuda.txt
+++ b/requirements-tensorflow-cuda.txt
@@ -1,10 +1,12 @@
 # Tensorflow with cuda support.
-tensorflow[and-cuda]~=2.18.0 # Pin to TF 2.16
+tensorflow[and-cuda]~=2.18.0
+tf2onnx

 # Torch cpu-only version (needed for testing).
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch>=2.1.0
 torchvision>=0.16.0
+torch-xla

 # Jax cpu-only version (needed for testing).
 jax[cpu]
diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt
index a325755201f..d165faa1628 100644
--- a/requirements-torch-cuda.txt
+++ b/requirements-torch-cuda.txt
@@ -1,10 +1,12 @@
 # Tensorflow cpu-only version (needed for testing).
-tensorflow-cpu~=2.18.0 # Pin to TF 2.16
+tensorflow-cpu~=2.18.0
+tf2onnx

 # Torch with cuda support.
 --extra-index-url https://download.pytorch.org/whl/cu121
 torch==2.5.1+cu121
 torchvision==0.20.1+cu121
+torch-xla

 # Jax cpu-only version (needed for testing).
 jax[cpu]
diff --git a/requirements.txt b/requirements.txt
index cecfc93a2b6..0973be4969a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,15 @@
 # Tensorflow.
-tensorflow-cpu~=2.18.0;sys_platform != 'darwin' # Pin to TF 2.16
+tensorflow-cpu~=2.18.0;sys_platform != 'darwin'
 tensorflow~=2.18.0;sys_platform == 'darwin'
 tf_keras
+tf2onnx

 # Torch.
 # TODO: Pin to < 2.3.0 (GitHub issue #19602)
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch>=2.1.0
 torchvision>=0.16.0
+torch-xla

 # Jax.
 jax[cpu]
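A note on the recurring pattern in the preprocessing hunks above (`random_translation`, `random_zoom`, `solarization`): the transform bodies are now wrapped in `if training:`, so each augmentation reduces to the identity at inference time. The sketch below illustrates that contract with an arbitrary augmentation layer; it is an editorial example assuming a standard Keras 3 install, not code from this diff.

```python
import numpy as np
from keras import layers

# Augmentation layers apply their random transform only when
# training=True; with training=False the input passes through unchanged.
layer = layers.RandomTranslation(height_factor=0.3, width_factor=0.3, seed=42)
images = np.random.uniform(0, 255, size=(2, 8, 8, 3)).astype("float32")

augmented = layer(images, training=True)     # random shifts applied
passthrough = layer(images, training=False)  # identity at inference
np.testing.assert_allclose(np.asarray(passthrough), images)
```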
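The `Model.export` changes above add an `"onnx"` format, with `onnxruntime` pulled into the test requirements. Here is a minimal end-to-end sketch of the new path, mirroring the pattern used in `model_test.py`; the model, file name, and shapes are made up for illustration.

```python
import numpy as np
import onnxruntime
import keras

# Build a small functional model and export it via the new ONNX format.
inputs = keras.Input(shape=(3,))
outputs = keras.layers.Dense(4)(inputs)
model = keras.Model(inputs, outputs)
model.export("model.onnx", format="onnx")

# Run inference with onnxruntime, as in the updated test.
x = np.random.rand(2, 3).astype("float32")
sess = onnxruntime.InferenceSession("model.onnx")
ort_inputs = {i.name: v for i, v in zip(sess.get_inputs(), [x])}
print(sess.run(None, ort_inputs)[0])
```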
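The expanded `add_variable_from_reference` docstring is easiest to read next to a concrete subclass. Below is a minimal momentum-style optimizer sketched against the documented API; it is illustrative only (the built-in `SGD` is the canonical implementation), and the class name is invented.

```python
from keras import ops
from keras import optimizers


class ToyMomentum(optimizers.Optimizer):
    """Illustrative optimizer showing per-variable slot creation."""

    def __init__(self, learning_rate=0.01, momentum=0.9, **kwargs):
        super().__init__(learning_rate=learning_rate, **kwargs)
        self.momentum = momentum

    def build(self, variables):
        if self.built:
            return
        super().build(variables)
        # One zero-initialized slot per model variable, with the same
        # shape and dtype, created via add_variable_from_reference.
        self.momentums = [
            self.add_variable_from_reference(v, name="momentum")
            for v in variables
        ]

    def update_step(self, gradient, variable, learning_rate):
        lr = ops.cast(learning_rate, variable.dtype)
        g = ops.cast(gradient, variable.dtype)
        m = self.momentums[self._get_variable_index(variable)]
        # m <- momentum * m - lr * g; variable <- variable + m
        self.assign(
            m, ops.subtract(ops.multiply(m, self.momentum), ops.multiply(g, lr))
        )
        self.assign_add(variable, m)
```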
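Finally, the `module_utils.py` hunk threads a custom `import_error_msg` through `LazyModule` so heavyweight optional dependencies such as `torch_xla` can fail with actionable instructions. Below is a self-contained condensation of the lazy-import pattern, assembled from the fragments in the hunk (the `_api_export_path` special case and availability caching are omitted for brevity).

```python
import importlib


class LazyModule:
    """Defers importing `name` until the first attribute access."""

    def __init__(self, name, pip_name=None, import_error_msg=None):
        self.name = name
        self.pip_name = pip_name or name
        self.import_error_msg = import_error_msg or (
            f"This requires the {self.name} module. "
            f"You can install it via `pip install {self.pip_name}`"
        )
        self.module = None

    def initialize(self):
        try:
            self.module = importlib.import_module(self.name)
        except ImportError:
            raise ImportError(self.import_error_msg)

    def __getattr__(self, name):
        # Only reached for attributes not set in __init__.
        if self.module is None:
            self.initialize()
        return getattr(self.module, name)


# Nothing is imported at construction time; the import (and any custom
# error message) only fires on first use.
scipy = LazyModule("scipy")
print(scipy.__version__)
```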