Skip to content

Commit

Permalink
TYP: multi ccy derivs (#625)
Browse files Browse the repository at this point in the history
Co-authored-by: JHM Darbyshire (win11) <[email protected]>
  • Loading branch information
attack68 and attack68 authored Jan 13, 2025
1 parent 4ab77f0 commit cd62a1e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 38 deletions.
59 changes: 31 additions & 28 deletions python/rateslib/instruments/rates/multi_currency.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def npv(
"""
self._set_pricing_mid(curves, solver, fx)

curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -139,10 +139,10 @@ def npv(
"Must have some FX information to price FXExchange, either `fx` or "
"`solver` containing an FX object.",
)
if not isinstance(fx_, FXRates | FXForwards):
elif not isinstance(fx_, FXRates | FXForwards):
# force base_ leg1 currency to be converted consistent.
leg1_npv = self.leg1.npv(curves[0], curves[1], fx_, base_, local)
leg2_npv = self.leg2.npv(curves[2], curves[3], 1.0, base_, local)
leg1_npv = self.leg1.npv(curves_[0], curves_[1], fx_, base_, local)
leg2_npv = self.leg2.npv(curves_[2], curves_[3], 1.0, base_, local)
warnings.warn(
"When valuing multi-currency derivatives it not best practice to "
"supply `fx` as numeric.\nYour input:\n"
Expand All @@ -155,8 +155,8 @@ def npv(
UserWarning,
)
else:
leg1_npv = self.leg1.npv(curves[0], curves[1], fx_, base_, local)
leg2_npv = self.leg2.npv(curves[2], curves[3], fx_, base_, local)
leg1_npv = self.leg1.npv(curves_[0], curves_[1], fx_, base_, local)
leg2_npv = self.leg2.npv(curves_[2], curves_[3], fx_, base_, local)

if local:
return {
Expand All @@ -178,7 +178,7 @@ def cashflows(
For arguments see :meth:`BaseMixin.npv<rateslib.instruments.BaseMixin.cashflows>`
"""
self._set_pricing_mid(curves, solver, fx)
curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -187,8 +187,8 @@ def cashflows(
NoInput(0),
)
seq = [
self.leg1.cashflows(curves[0], curves[1], fx_, base_),
self.leg2.cashflows(curves[2], curves[3], fx_, base_),
self.leg1.cashflows(curves_[0], curves_[1], fx_, base_),
self.leg2.cashflows(curves_[2], curves_[3], fx_, base_),
]
_ = DataFrame.from_records(seq)
_.index = MultiIndex.from_tuples([("leg1", 0), ("leg2", 0)])
Expand All @@ -206,7 +206,7 @@ def rate(
For arguments see :meth:`BaseMixin.rate<rateslib.instruments.BaseMixin.rate>`
"""
curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -215,7 +215,7 @@ def rate(
self.leg1.currency,
)
if isinstance(fx_, FXRates | FXForwards):
imm_fx = fx_.rate(self.pair)
imm_fx: FX_ = fx_.rate(self.pair)
else:
imm_fx = fx_

Expand All @@ -224,7 +224,7 @@ def rate(
"`fx` must be supplied to price FXExchange object.\n"
"Note: it can be attached to and then gotten from a Solver.",
)
_ = forward_fx(self.settlement, curves[1], curves[3], imm_fx)
_ = forward_fx(self.settlement, curves_[1], curves_[3], imm_fx)
return _

def delta(self, *args: Any, **kwargs: Any) -> DataFrame:
Expand Down Expand Up @@ -328,6 +328,9 @@ class XCS(BaseDerivative):
Required keyword arguments for :class:`~rateslib.instruments.BaseDerivative`.
"""

leg1: FixedLeg | FloatLeg
leg2: FixedLeg | FloatLeg | FloatLegMtm | FixedLegMtm

def __init__(
self,
*args: Any,
Expand Down Expand Up @@ -734,7 +737,7 @@ def rate(

return _ if _is_float_tgt_leg else _ * 0.01

def spread(self, *args, **kwargs):
def spread(self, *args: Any, **kwargs: Any) -> DualTypes:
"""
Alias for :meth:`~rateslib.instruments.BaseXCS.rate`
"""
Expand All @@ -747,7 +750,7 @@ def cashflows(
fx: FXForwards | NoInput = NoInput(0),
base: str_ = NoInput(0),
):
curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -757,13 +760,13 @@ def cashflows(
)

if self._is_unpriced:
self._set_pricing_mid(curves, solver, fx_)
self._set_pricing_mid(curves_, solver, fx_)

self._set_fx_fixings(fx_)
if self._is_mtm:
self.leg2._do_not_repeat_set_periods = True

ret = super().cashflows(curves, solver, fx_, base_)
ret = super().cashflows(curves_, solver, fx_, base_)
if self._is_mtm:
self.leg2._do_not_repeat_set_periods = False # reset the mtm calc
return ret
Expand Down Expand Up @@ -810,7 +813,7 @@ def fixings_table(
-------
DataFrame
"""
curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -821,8 +824,8 @@ def fixings_table(

try:
df1 = self.leg1.fixings_table(
curve=curves[0],
disc_curve=curves[1],
curve=curves_[0],
disc_curve=curves_[1],
fx=fx_,
base=base_,
approximate=approximate,
Expand All @@ -835,8 +838,8 @@ def fixings_table(

try:
df2 = self.leg2.fixings_table(
curve=curves[2],
disc_curve=curves[3],
curve=curves_[2],
disc_curve=curves_[3],
fx=fx_,
base=base_,
approximate=approximate,
Expand Down Expand Up @@ -1146,7 +1149,7 @@ def _set_pricing_mid(
curves: Curves_ = NoInput(0),
solver: Solver_ = NoInput(0),
fx: FXForwards | NoInput = NoInput(0),
):
) -> None:
# This function ASSUMES that the instrument is unpriced, i.e. all of
# split_notional, fx_fixing and points have been initialised as None.

Expand All @@ -1161,7 +1164,7 @@ def rate(
solver: Solver_ = NoInput(0),
fx: FXForwards | NoInput = NoInput(0),
fixed_rate: bool = False,
):
) -> DualTypes:
"""
Return the mid-market pricing parameter of the FXSwapS.
Expand All @@ -1188,7 +1191,7 @@ def rate(
-------
float, Dual or Dual2
"""
curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -1197,11 +1200,11 @@ def rate(
self.leg1.currency,
)
# set the split notional from the curve if not available
self._set_split_notional(curve=curves[1])
self._set_split_notional(curve=curves_[1])
# then we will set the fx_fixing and leg2 initial notional.

# self._set_fx_fixings(fx) # this will be done by super().rate()
leg2_fixed_rate = super().rate(curves, solver, fx_, leg=2)
leg2_fixed_rate = super().rate(curves_, solver, fx_, leg=2)

if fixed_rate:
return leg2_fixed_rate
Expand All @@ -1218,8 +1221,8 @@ def cashflows(
solver: Solver_ = NoInput(0),
fx: FXForwards | NoInput = NoInput(0),
base: str_ = NoInput(0),
):
) -> DataFrame:
if self._is_unpriced:
self._set_pricing_mid(curves, solver, fx)
ret = super().cashflows(curves, solver, fx, base)
ret: DataFrame = super().cashflows(curves, solver, fx, base)
return ret
30 changes: 20 additions & 10 deletions python/rateslib/periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,17 @@
from rateslib.splines import evaluate

if TYPE_CHECKING:
from rateslib.typing import FX_, CalInput, CalTypes, Curve_, CurveOption_, DualTypes, Number
from rateslib.typing import (
FX_,
NPV,
CalInput,
CalTypes,
Curve_,
CurveOption_,
DualTypes,
Number,
str_,
)

# Licence: Creative Commons - Attribution-NonCommercial-NoDerivatives 4.0 International
# Commercial use of this code, and/or copying and redistribution is prohibited.
Expand Down Expand Up @@ -2684,12 +2694,12 @@ def rate(self) -> DualTypes | None:

def npv(
self,
curve: Curve | NoInput = NoInput(0),
disc_curve: Curve | NoInput = NoInput(0),
fx: float | FXRates | FXForwards | NoInput = NoInput(0),
base: str | NoInput = NoInput(0),
curve: CurveOption_ = NoInput(0),
disc_curve: Curve_ = NoInput(0),
fx: FX_ = NoInput(0),
base: str_ = NoInput(0),
local: bool = False,
) -> DualTypes | dict[str, DualTypes]:
) -> NPV:
"""
Return the NPV of the *Cashflow*.
See
Expand All @@ -2701,10 +2711,10 @@ def npv(

def cashflows(
self,
curve: Curve | NoInput = NoInput(0),
disc_curve: Curve | NoInput = NoInput(0),
fx: float | FXRates | FXForwards | NoInput = NoInput(0),
base: str | NoInput = NoInput(0),
curve: CurveOption_ = NoInput(0),
disc_curve: Curve_ = NoInput(0),
fx: FX_ = NoInput(0),
base: str_ = NoInput(0),
) -> dict[str, Any]:
"""
Return the cashflows of the *Cashflow*.
Expand Down

0 comments on commit cd62a1e

Please sign in to comment.