From 8d4126b3adda3c59abd003c17da7f1135d862c33 Mon Sep 17 00:00:00 2001 From: "JHM Darbyshire (win11)" Date: Mon, 13 Jan 2025 12:22:33 +0100 Subject: [PATCH] TYP: multi ccy derivs --- .../instruments/rates/multi_currency.py | 59 ++++++++++--------- python/rateslib/periods.py | 30 ++++++---- 2 files changed, 51 insertions(+), 38 deletions(-) diff --git a/python/rateslib/instruments/rates/multi_currency.py b/python/rateslib/instruments/rates/multi_currency.py index 0ed2d6a7..a9712225 100644 --- a/python/rateslib/instruments/rates/multi_currency.py +++ b/python/rateslib/instruments/rates/multi_currency.py @@ -125,7 +125,7 @@ def npv( """ self._set_pricing_mid(curves, solver, fx) - curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( + curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( self.curves, solver, curves, @@ -139,10 +139,10 @@ def npv( "Must have some FX information to price FXExchange, either `fx` or " "`solver` containing an FX object.", ) - if not isinstance(fx_, FXRates | FXForwards): + elif not isinstance(fx_, FXRates | FXForwards): # force base_ leg1 currency to be converted consistent. - leg1_npv = self.leg1.npv(curves[0], curves[1], fx_, base_, local) - leg2_npv = self.leg2.npv(curves[2], curves[3], 1.0, base_, local) + leg1_npv = self.leg1.npv(curves_[0], curves_[1], fx_, base_, local) + leg2_npv = self.leg2.npv(curves_[2], curves_[3], 1.0, base_, local) warnings.warn( "When valuing multi-currency derivatives it not best practice to " "supply `fx` as numeric.\nYour input:\n" @@ -155,8 +155,8 @@ def npv( UserWarning, ) else: - leg1_npv = self.leg1.npv(curves[0], curves[1], fx_, base_, local) - leg2_npv = self.leg2.npv(curves[2], curves[3], fx_, base_, local) + leg1_npv = self.leg1.npv(curves_[0], curves_[1], fx_, base_, local) + leg2_npv = self.leg2.npv(curves_[2], curves_[3], fx_, base_, local) if local: return { @@ -178,7 +178,7 @@ def cashflows( For arguments see :meth:`BaseMixin.npv` """ self._set_pricing_mid(curves, solver, fx) - curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( + curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( self.curves, solver, curves, @@ -187,8 +187,8 @@ def cashflows( NoInput(0), ) seq = [ - self.leg1.cashflows(curves[0], curves[1], fx_, base_), - self.leg2.cashflows(curves[2], curves[3], fx_, base_), + self.leg1.cashflows(curves_[0], curves_[1], fx_, base_), + self.leg2.cashflows(curves_[2], curves_[3], fx_, base_), ] _ = DataFrame.from_records(seq) _.index = MultiIndex.from_tuples([("leg1", 0), ("leg2", 0)]) @@ -206,7 +206,7 @@ def rate( For arguments see :meth:`BaseMixin.rate` """ - curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( + curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( self.curves, solver, curves, @@ -215,7 +215,7 @@ def rate( self.leg1.currency, ) if isinstance(fx_, FXRates | FXForwards): - imm_fx = fx_.rate(self.pair) + imm_fx: FX_ = fx_.rate(self.pair) else: imm_fx = fx_ @@ -224,7 +224,7 @@ def rate( "`fx` must be supplied to price FXExchange object.\n" "Note: it can be attached to and then gotten from a Solver.", ) - _ = forward_fx(self.settlement, curves[1], curves[3], imm_fx) + _ = forward_fx(self.settlement, curves_[1], curves_[3], imm_fx) return _ def delta(self, *args: Any, **kwargs: Any) -> DataFrame: @@ -328,6 +328,9 @@ class XCS(BaseDerivative): Required keyword arguments for :class:`~rateslib.instruments.BaseDerivative`. """ + leg1: FixedLeg | FloatLeg + leg2: FixedLeg | FloatLeg | FloatLegMtm | FixedLegMtm + def __init__( self, *args: Any, @@ -734,7 +737,7 @@ def rate( return _ if _is_float_tgt_leg else _ * 0.01 - def spread(self, *args, **kwargs): + def spread(self, *args: Any, **kwargs: Any) -> DualTypes: """ Alias for :meth:`~rateslib.instruments.BaseXCS.rate` """ @@ -747,7 +750,7 @@ def cashflows( fx: FXForwards | NoInput = NoInput(0), base: str_ = NoInput(0), ): - curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( + curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( self.curves, solver, curves, @@ -757,13 +760,13 @@ def cashflows( ) if self._is_unpriced: - self._set_pricing_mid(curves, solver, fx_) + self._set_pricing_mid(curves_, solver, fx_) self._set_fx_fixings(fx_) if self._is_mtm: self.leg2._do_not_repeat_set_periods = True - ret = super().cashflows(curves, solver, fx_, base_) + ret = super().cashflows(curves_, solver, fx_, base_) if self._is_mtm: self.leg2._do_not_repeat_set_periods = False # reset the mtm calc return ret @@ -810,7 +813,7 @@ def fixings_table( ------- DataFrame """ - curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( + curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( self.curves, solver, curves, @@ -821,8 +824,8 @@ def fixings_table( try: df1 = self.leg1.fixings_table( - curve=curves[0], - disc_curve=curves[1], + curve=curves_[0], + disc_curve=curves_[1], fx=fx_, base=base_, approximate=approximate, @@ -835,8 +838,8 @@ def fixings_table( try: df2 = self.leg2.fixings_table( - curve=curves[2], - disc_curve=curves[3], + curve=curves_[2], + disc_curve=curves_[3], fx=fx_, base=base_, approximate=approximate, @@ -1146,7 +1149,7 @@ def _set_pricing_mid( curves: Curves_ = NoInput(0), solver: Solver_ = NoInput(0), fx: FXForwards | NoInput = NoInput(0), - ): + ) -> None: # This function ASSUMES that the instrument is unpriced, i.e. all of # split_notional, fx_fixing and points have been initialised as None. @@ -1161,7 +1164,7 @@ def rate( solver: Solver_ = NoInput(0), fx: FXForwards | NoInput = NoInput(0), fixed_rate: bool = False, - ): + ) -> DualTypes: """ Return the mid-market pricing parameter of the FXSwapS. @@ -1188,7 +1191,7 @@ def rate( ------- float, Dual or Dual2 """ - curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( + curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver( self.curves, solver, curves, @@ -1197,11 +1200,11 @@ def rate( self.leg1.currency, ) # set the split notional from the curve if not available - self._set_split_notional(curve=curves[1]) + self._set_split_notional(curve=curves_[1]) # then we will set the fx_fixing and leg2 initial notional. # self._set_fx_fixings(fx) # this will be done by super().rate() - leg2_fixed_rate = super().rate(curves, solver, fx_, leg=2) + leg2_fixed_rate = super().rate(curves_, solver, fx_, leg=2) if fixed_rate: return leg2_fixed_rate @@ -1218,8 +1221,8 @@ def cashflows( solver: Solver_ = NoInput(0), fx: FXForwards | NoInput = NoInput(0), base: str_ = NoInput(0), - ): + ) -> DataFrame: if self._is_unpriced: self._set_pricing_mid(curves, solver, fx) - ret = super().cashflows(curves, solver, fx, base) + ret: DataFrame = super().cashflows(curves, solver, fx, base) return ret diff --git a/python/rateslib/periods.py b/python/rateslib/periods.py index c2dce30b..08a8ba80 100644 --- a/python/rateslib/periods.py +++ b/python/rateslib/periods.py @@ -61,7 +61,17 @@ from rateslib.splines import evaluate if TYPE_CHECKING: - from rateslib.typing import FX_, CalInput, CalTypes, Curve_, CurveOption_, DualTypes, Number + from rateslib.typing import ( + FX_, + NPV, + CalInput, + CalTypes, + Curve_, + CurveOption_, + DualTypes, + Number, + str_, + ) # Licence: Creative Commons - Attribution-NonCommercial-NoDerivatives 4.0 International # Commercial use of this code, and/or copying and redistribution is prohibited. @@ -2684,12 +2694,12 @@ def rate(self) -> DualTypes | None: def npv( self, - curve: Curve | NoInput = NoInput(0), - disc_curve: Curve | NoInput = NoInput(0), - fx: float | FXRates | FXForwards | NoInput = NoInput(0), - base: str | NoInput = NoInput(0), + curve: CurveOption_ = NoInput(0), + disc_curve: Curve_ = NoInput(0), + fx: FX_ = NoInput(0), + base: str_ = NoInput(0), local: bool = False, - ) -> DualTypes | dict[str, DualTypes]: + ) -> NPV: """ Return the NPV of the *Cashflow*. See @@ -2701,10 +2711,10 @@ def npv( def cashflows( self, - curve: Curve | NoInput = NoInput(0), - disc_curve: Curve | NoInput = NoInput(0), - fx: float | FXRates | FXForwards | NoInput = NoInput(0), - base: str | NoInput = NoInput(0), + curve: CurveOption_ = NoInput(0), + disc_curve: Curve_ = NoInput(0), + fx: FX_ = NoInput(0), + base: str_ = NoInput(0), ) -> dict[str, Any]: """ Return the cashflows of the *Cashflow*.