diff --git a/tf_quant_finance/__init__.py b/tf_quant_finance/__init__.py index 9be999d33..b438ca0b7 100644 --- a/tf_quant_finance/__init__.py +++ b/tf_quant_finance/__init__.py @@ -28,7 +28,7 @@ # Ensure Python 3 is used. -def _check_py_version(): +def _check_py_version() -> None: if sys.version_info[0] < 3: raise Exception("Please use Python 3. Python 2 is not supported.") @@ -36,7 +36,7 @@ def _check_py_version(): # Ensure TensorFlow is importable and its version is sufficiently recent. This # needs to happen before anything else, since the imports below will try to # import tensorflow, too. -def _ensure_tf_install(): # pylint: disable=g-statement-before-imports +def _ensure_tf_install() -> None: # pylint: disable=g-statement-before-imports """Attempt to import tensorflow, and ensure its version is sufficient. Raises: diff --git a/tf_quant_finance/utils/dataclass.py b/tf_quant_finance/utils/dataclass.py index e5949dcbd..e516bf4ef 100644 --- a/tf_quant_finance/utils/dataclass.py +++ b/tf_quant_finance/utils/dataclass.py @@ -76,13 +76,13 @@ class are treated as ordered in the same order as they appear in the class cls = attr.s(cls, auto_attribs=True) # Define __iter__ and __len__ method to ensure tf.while_loop compatibility - def __iter__(self): # pylint: disable=invalid-name + def __iter__(self) -> None: # pylint: disable=invalid-name # Note that self.__attrs_attrs__ is a tuple so the iteration order is fixed for item in self.__attrs_attrs__: name = item.name yield getattr(self, name) - def __len__(self): # pylint: disable=invalid-name + def __len__(self) -> int: # pylint: disable=invalid-name return len(self.__attrs_attrs__) cls.__len__ = __len__ diff --git a/tf_quant_finance/utils/dataclass_test.py b/tf_quant_finance/utils/dataclass_test.py index c60e5fb2a..09a8c9cd6 100644 --- a/tf_quant_finance/utils/dataclass_test.py +++ b/tf_quant_finance/utils/dataclass_test.py @@ -14,6 +14,7 @@ # limitations under the License. """Tests for the Coord decorator.""" +from typing import Tuple import tensorflow.compat.v2 as tf import tf_quant_finance as tff @@ -31,9 +32,9 @@ class Coords: @tf.function def fn(start_coords: Coords) -> Coords: - def cond(it, _): + def cond(it, _) -> bool: return it < 10 - def body(it, coords): + def body(it, coords) -> Tuple(int, Coords): return it + 1, Coords(x=coords.x + 1, y=coords.y + 2) return tf.while_loop(cond, body, loop_vars=(0, start_coords))[1] @@ -47,7 +48,7 @@ def body(it, coords): with self.subTest('SecondValue'): self.assertEqual(end_coords_eval.y, 20) - def test_docstring_preservation(self): + def test_docstring_preservation(self) -> None: @tff.utils.dataclass class Coords: """A coordinate grid.""" diff --git a/tf_quant_finance/utils/tf_functions_test.py b/tf_quant_finance/utils/tf_functions_test.py index fd39d0f9f..6e626094b 100644 --- a/tf_quant_finance/utils/tf_functions_test.py +++ b/tf_quant_finance/utils/tf_functions_test.py @@ -19,7 +19,6 @@ import tensorflow.compat.v2 as tf import tf_quant_finance as tff -from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import @dataclasses.dataclass @@ -33,7 +32,7 @@ class TfFunctionsTest(tf.test.TestCase): def _assert_keys_values(self, iterator, expected_keys=None, - expected_values=None): + expected_values=None) -> None: key_lists, values = zip(*list(iterator)) keys = ['_'.join(k) for k in key_lists] if expected_keys is not None and expected_values is not None: @@ -49,33 +48,33 @@ def _assert_keys_values(self, if expected_values is not None: self.assertSameElements(values, expected_values) - def test_non_nested(self): + def test_non_nested(self) -> None: d = {'a': 1, 'b': 2} iterator = tff.utils.iterate_nested(d) self._assert_keys_values( iterator, expected_keys=['a', 'b'], expected_values=[1, 2]) - def test_empty(self): + def test_empty(self) -> None: items = [] for item in tff.utils.iterate_nested({}): items.append(item) self.assertEmpty(items) - def test_array_values(self): + def test_array_values(self) -> None: d = {'a': [1, 2, 3], 'b': {'c': [4, 5]}} self._assert_keys_values( tff.utils.iterate_nested(d), expected_keys=['a', 'b_c'], expected_values=[[1, 2, 3], [4, 5]]) - def test_nested(self): + def test_nested(self) -> None: nested_dict = {'a': 1, 'b': [2, 3, 4], 'c': {'d': 8}} self._assert_keys_values( tff.utils.iterate_nested(nested_dict), expected_keys=['a', 'b', 'c_d'], expected_values=[1, [2, 3, 4], 8]) - def test_dataclass(self): + def test_dataclass(self) -> None: d = { 'a': { 'b': {