Skip to content

Commit

Permalink
TYP: add to dual directory (#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
attack68 authored Dec 6, 2024
1 parent 4957040 commit 8a503a1
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 35 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ ignore = [
files = [
"python/rateslib/calendars/**/*.py",
"python/rateslib/fx/**/*.py",
"python/rateslib/dual/**/*.py"
]
strict = true

Expand Down
34 changes: 24 additions & 10 deletions python/rateslib/dual/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Contact rateslib at gmail.com if this code is observed outside its intended sphere.


def set_order(val, order):
def set_order(val: Number, order: int) -> Number:
"""
Changes the order of a :class:`Dual` or :class:`Dual2` leaving floats and ints
unchanged.
Expand Down Expand Up @@ -49,7 +49,9 @@ def set_order(val, order):
return val


def set_order_convert(val, order, tag, vars_from=None):
def set_order_convert(
val: Number, order: int, tag: list[str] | None, vars_from: Dual | Dual2 | None = None
) -> Number:
"""
Convert a float, :class:`Dual` or :class:`Dual2` type to a specified alternate type.
Expand Down Expand Up @@ -87,7 +89,14 @@ def set_order_convert(val, order, tag, vars_from=None):
return set_order(val, order)


def gradient(dual, vars: list[str] | None = None, order: int = 1, keep_manifold: bool = False):
def gradient(
dual: Dual | Dual2 | Variable,
vars: list[str] | None = None,
order: int = 1,
keep_manifold: bool = False,
) -> (
np.ndarray[tuple[int], np.dtype[np.float64]] | np.ndarray[tuple[int, int], np.dtype[np.float64]]
):
"""
Return derivatives of a dual number.
Expand Down Expand Up @@ -135,7 +144,7 @@ def gradient(dual, vars: list[str] | None = None, order: int = 1, keep_manifold:
raise ValueError("`order` must be in {1, 2} for gradient calculation.")


def dual_exp(x):
def dual_exp(x: DualTypes) -> Number:
"""
Calculate the exponential value of a regular int or float or a dual number.
Expand All @@ -153,7 +162,7 @@ def dual_exp(x):
return math.exp(x)


def dual_log(x, base=None):
def dual_log(x: DualTypes, base: int | None = None) -> Number:
"""
Calculate the logarithm of a regular int or float or a dual number.
Expand All @@ -180,7 +189,7 @@ def dual_log(x, base=None):
return math.log(x, base)


def dual_norm_pdf(x):
def dual_norm_pdf(x: DualTypes) -> Number:
"""
Return the standard normal probability density function.
Expand All @@ -195,7 +204,7 @@ def dual_norm_pdf(x):
return dual_exp(-0.5 * x**2) / math.sqrt(2.0 * math.pi)


def dual_norm_cdf(x):
def dual_norm_cdf(x: DualTypes) -> Number:
"""
Return the cumulative standard normal distribution for given value.
Expand All @@ -213,7 +222,7 @@ def dual_norm_cdf(x):
return NormalDist().cdf(x)


def dual_inv_norm_cdf(x):
def dual_inv_norm_cdf(x: DualTypes) -> Number:
"""
Return the inverse cumulative standard normal distribution for given value.
Expand All @@ -231,7 +240,12 @@ def dual_inv_norm_cdf(x):
return NormalDist().inv_cdf(x)


def dual_solve(A, b, allow_lsq=False, types=(Dual, Dual)):
def dual_solve(
A: np.ndarray[tuple[int, int], np.dtype[np.object_]],
b: np.ndarray[tuple[int], np.dtype[np.object_]],
allow_lsq: bool = False,
types: tuple[Number, Number] = (Dual, Dual),
) -> np.ndarray[tuple[int], np.dtype[np.object_]]:
"""
Solve a linear system of equations involving dual number data types.
Expand Down Expand Up @@ -289,7 +303,7 @@ def dual_solve(A, b, allow_lsq=False, types=(Dual, Dual)):
return np.array(out)[:, None]


def _get_adorder(order: int):
def _get_adorder(order: int) -> ADOrder:
if order == 1:
return ADOrder.One
elif order == 0:
Expand Down
51 changes: 26 additions & 25 deletions python/rateslib/dual/variable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import math
from typing import Any

import numpy as np

Expand Down Expand Up @@ -43,17 +44,17 @@ def __init__(
self,
real: float,
vars: tuple[str, ...] = (),
dual: np.ndarray | NoInput = NoInput(0),
dual: np.ndarray[tuple[int, int], np.dtype[np.float64]] | NoInput = NoInput(0),
):
self.real: float = float(real)
self.vars: tuple[str, ...] = tuple(vars)
n = len(self.vars)
if dual is NoInput.blank or len(dual) == 0:
self.dual: np.ndarray = np.ones(n)
if isinstance(dual, NoInput) or len(dual) == 0:
self.dual: np.ndarray = np.ones(n, dtype=np.float64)
else:
self.dual = np.asarray(dual.copy())

def _to_dual_type(self, order):
def _to_dual_type(self, order: int) -> Dual | Dual2:
if order == 1:
return Dual(self.real, vars=self.vars, dual=self.dual)
elif order == 2:
Expand All @@ -63,7 +64,7 @@ def _to_dual_type(self, order):
f"`Variable` can only be converted with `order` in [1, 2], got order: {order}."
)

def __eq__(self, argument):
def __eq__(self, argument: Any) -> bool:
"""
Compare an argument with a Variable for equality.
This does not account for variable ordering.
Expand All @@ -74,19 +75,19 @@ def __eq__(self, argument):
return self.__eq_coeffs__(argument, PRECISION)
return False

def __lt__(self, other):
def __lt__(self, other: Any) -> bool:
return self.real.__lt__(other)

def __le__(self, other):
def __le__(self, other: Any) -> bool:
return self.real.__le__(other)

def __gt__(self, other):
def __gt__(self, other: Any) -> bool:
return self.real.__gt__(other)

def __ge__(self, other):
def __ge__(self, other: Any) -> bool:
return self.real.__ge__(other)

def __eq_coeffs__(self, argument, precision):
def __eq_coeffs__(self, argument: Dual | Dual2 | Variable, precision: float) -> bool:
"""Compare the coefficients of two dual array numbers for equality."""
return not (
not math.isclose(self.real, argument.real, abs_tol=precision)
Expand All @@ -99,10 +100,10 @@ def __eq_coeffs__(self, argument, precision):
# and https://github.com/PyO3/pyo3/discussions/3911
# return self.real

def __neg__(self):
def __neg__(self) -> Variable:
return Variable(-self.real, vars=self.vars, dual=-self.dual)

def __add__(self, other):
def __add__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
if isinstance(other, Variable):
_1 = self._to_dual_type(defaults._global_ad_order)
_2 = other._to_dual_type(defaults._global_ad_order)
Expand All @@ -118,16 +119,16 @@ def __add__(self, other):
else:
raise TypeError(f"No operation defined between `Variable` and type: `{type(other)}`")

def __radd__(self, other):
def __radd__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
return self.__add__(other)

def __rsub__(self, other):
def __rsub__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
return (self.__neg__()).__add__(other)

def __sub__(self, other):
def __sub__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
return self.__add__(other.__neg__())

def __mul__(self, other):
def __mul__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
if isinstance(other, Variable):
_1 = self._to_dual_type(defaults._global_ad_order)
_2 = other._to_dual_type(defaults._global_ad_order)
Expand All @@ -143,10 +144,10 @@ def __mul__(self, other):
else:
raise TypeError(f"No operation defined between `Variable` and type: `{type(other)}`")

def __rmul__(self, other):
def __rmul__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
return self.__mul__(other)

def __truediv__(self, other):
def __truediv__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
if isinstance(other, Variable):
_1 = self._to_dual_type(defaults._global_ad_order)
_2 = other._to_dual_type(defaults._global_ad_order)
Expand All @@ -162,7 +163,7 @@ def __truediv__(self, other):
else:
raise TypeError(f"No operation defined between `Variable` and type: `{type(other)}`")

def __rtruediv__(self, other):
def __rtruediv__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
if isinstance(other, Variable):
# cannot reach this line
raise TypeError("Impossible line execution - please report issue.") # pragma: no cover
Expand All @@ -178,27 +179,27 @@ def __rtruediv__(self, other):
else:
raise TypeError(f"No operation defined between `Variable` and type: `{type(other)}`")

def __exp__(self):
def __exp__(self) -> Dual | Dual2:
_1 = self._to_dual_type(defaults._global_ad_order)
return _1.__exp__()

def __log__(self):
def __log__(self) -> Dual | Dual2:
_1 = self._to_dual_type(defaults._global_ad_order)
return _1.__log__()

def __norm_cdf__(self):
def __norm_cdf__(self) -> Dual | Dual2:
_1 = self._to_dual_type(defaults._global_ad_order)
return _1.__norm_cdf__()

def __norm_inv_cdf__(self):
def __norm_inv_cdf__(self) -> Dual | Dual2:
_1 = self._to_dual_type(defaults._global_ad_order)
return _1.__norm_inv_cdf__()

def __pow__(self, exponent):
def __pow__(self, exponent: float) -> Dual | Dual2:
_1 = self._to_dual_type(defaults._global_ad_order)
return _1.__pow__(exponent)

def __repr__(self):
def __repr__(self) -> str:
a = ", ".join(self.vars[:3])
b = ", ".join([str(_) for _ in self.dual[:3]])
if len(self.vars) > 3:
Expand Down

0 comments on commit 8a503a1

Please sign in to comment.