Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: multi ccy derivs #625

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 31 additions & 28 deletions python/rateslib/instruments/rates/multi_currency.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def npv(
"""
self._set_pricing_mid(curves, solver, fx)

curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -139,10 +139,10 @@ def npv(
"Must have some FX information to price FXExchange, either `fx` or "
"`solver` containing an FX object.",
)
if not isinstance(fx_, FXRates | FXForwards):
elif not isinstance(fx_, FXRates | FXForwards):
# force base_ leg1 currency to be converted consistent.
leg1_npv = self.leg1.npv(curves[0], curves[1], fx_, base_, local)
leg2_npv = self.leg2.npv(curves[2], curves[3], 1.0, base_, local)
leg1_npv = self.leg1.npv(curves_[0], curves_[1], fx_, base_, local)
leg2_npv = self.leg2.npv(curves_[2], curves_[3], 1.0, base_, local)
warnings.warn(
"When valuing multi-currency derivatives it not best practice to "
"supply `fx` as numeric.\nYour input:\n"
Expand All @@ -155,8 +155,8 @@ def npv(
UserWarning,
)
else:
leg1_npv = self.leg1.npv(curves[0], curves[1], fx_, base_, local)
leg2_npv = self.leg2.npv(curves[2], curves[3], fx_, base_, local)
leg1_npv = self.leg1.npv(curves_[0], curves_[1], fx_, base_, local)
leg2_npv = self.leg2.npv(curves_[2], curves_[3], fx_, base_, local)

if local:
return {
Expand All @@ -178,7 +178,7 @@ def cashflows(
For arguments see :meth:`BaseMixin.npv<rateslib.instruments.BaseMixin.cashflows>`
"""
self._set_pricing_mid(curves, solver, fx)
curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -187,8 +187,8 @@ def cashflows(
NoInput(0),
)
seq = [
self.leg1.cashflows(curves[0], curves[1], fx_, base_),
self.leg2.cashflows(curves[2], curves[3], fx_, base_),
self.leg1.cashflows(curves_[0], curves_[1], fx_, base_),
self.leg2.cashflows(curves_[2], curves_[3], fx_, base_),
]
_ = DataFrame.from_records(seq)
_.index = MultiIndex.from_tuples([("leg1", 0), ("leg2", 0)])
Expand All @@ -206,7 +206,7 @@ def rate(

For arguments see :meth:`BaseMixin.rate<rateslib.instruments.BaseMixin.rate>`
"""
curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -215,7 +215,7 @@ def rate(
self.leg1.currency,
)
if isinstance(fx_, FXRates | FXForwards):
imm_fx = fx_.rate(self.pair)
imm_fx: FX_ = fx_.rate(self.pair)
else:
imm_fx = fx_

Expand All @@ -224,7 +224,7 @@ def rate(
"`fx` must be supplied to price FXExchange object.\n"
"Note: it can be attached to and then gotten from a Solver.",
)
_ = forward_fx(self.settlement, curves[1], curves[3], imm_fx)
_ = forward_fx(self.settlement, curves_[1], curves_[3], imm_fx)
return _

def delta(self, *args: Any, **kwargs: Any) -> DataFrame:
Expand Down Expand Up @@ -328,6 +328,9 @@ class XCS(BaseDerivative):
Required keyword arguments for :class:`~rateslib.instruments.BaseDerivative`.
"""

leg1: FixedLeg | FloatLeg
leg2: FixedLeg | FloatLeg | FloatLegMtm | FixedLegMtm

def __init__(
self,
*args: Any,
Expand Down Expand Up @@ -734,7 +737,7 @@ def rate(

return _ if _is_float_tgt_leg else _ * 0.01

def spread(self, *args, **kwargs):
def spread(self, *args: Any, **kwargs: Any) -> DualTypes:
"""
Alias for :meth:`~rateslib.instruments.BaseXCS.rate`
"""
Expand All @@ -747,7 +750,7 @@ def cashflows(
fx: FXForwards | NoInput = NoInput(0),
base: str_ = NoInput(0),
):
curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -757,13 +760,13 @@ def cashflows(
)

if self._is_unpriced:
self._set_pricing_mid(curves, solver, fx_)
self._set_pricing_mid(curves_, solver, fx_)

self._set_fx_fixings(fx_)
if self._is_mtm:
self.leg2._do_not_repeat_set_periods = True

ret = super().cashflows(curves, solver, fx_, base_)
ret = super().cashflows(curves_, solver, fx_, base_)
if self._is_mtm:
self.leg2._do_not_repeat_set_periods = False # reset the mtm calc
return ret
Expand Down Expand Up @@ -810,7 +813,7 @@ def fixings_table(
-------
DataFrame
"""
curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -821,8 +824,8 @@ def fixings_table(

try:
df1 = self.leg1.fixings_table(
curve=curves[0],
disc_curve=curves[1],
curve=curves_[0],
disc_curve=curves_[1],
fx=fx_,
base=base_,
approximate=approximate,
Expand All @@ -835,8 +838,8 @@ def fixings_table(

try:
df2 = self.leg2.fixings_table(
curve=curves[2],
disc_curve=curves[3],
curve=curves_[2],
disc_curve=curves_[3],
fx=fx_,
base=base_,
approximate=approximate,
Expand Down Expand Up @@ -1146,7 +1149,7 @@ def _set_pricing_mid(
curves: Curves_ = NoInput(0),
solver: Solver_ = NoInput(0),
fx: FXForwards | NoInput = NoInput(0),
):
) -> None:
# This function ASSUMES that the instrument is unpriced, i.e. all of
# split_notional, fx_fixing and points have been initialised as None.

Expand All @@ -1161,7 +1164,7 @@ def rate(
solver: Solver_ = NoInput(0),
fx: FXForwards | NoInput = NoInput(0),
fixed_rate: bool = False,
):
) -> DualTypes:
"""
Return the mid-market pricing parameter of the FXSwapS.

Expand All @@ -1188,7 +1191,7 @@ def rate(
-------
float, Dual or Dual2
"""
curves, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
curves_, fx_, base_ = _get_curves_fx_and_base_maybe_from_solver(
self.curves,
solver,
curves,
Expand All @@ -1197,11 +1200,11 @@ def rate(
self.leg1.currency,
)
# set the split notional from the curve if not available
self._set_split_notional(curve=curves[1])
self._set_split_notional(curve=curves_[1])
# then we will set the fx_fixing and leg2 initial notional.

# self._set_fx_fixings(fx) # this will be done by super().rate()
leg2_fixed_rate = super().rate(curves, solver, fx_, leg=2)
leg2_fixed_rate = super().rate(curves_, solver, fx_, leg=2)

if fixed_rate:
return leg2_fixed_rate
Expand All @@ -1218,8 +1221,8 @@ def cashflows(
solver: Solver_ = NoInput(0),
fx: FXForwards | NoInput = NoInput(0),
base: str_ = NoInput(0),
):
) -> DataFrame:
if self._is_unpriced:
self._set_pricing_mid(curves, solver, fx)
ret = super().cashflows(curves, solver, fx, base)
ret: DataFrame = super().cashflows(curves, solver, fx, base)
return ret
30 changes: 20 additions & 10 deletions python/rateslib/periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,17 @@
from rateslib.splines import evaluate

if TYPE_CHECKING:
from rateslib.typing import FX_, CalInput, CalTypes, Curve_, CurveOption_, DualTypes, Number
from rateslib.typing import (
FX_,
NPV,
CalInput,
CalTypes,
Curve_,
CurveOption_,
DualTypes,
Number,
str_,
)

# Licence: Creative Commons - Attribution-NonCommercial-NoDerivatives 4.0 International
# Commercial use of this code, and/or copying and redistribution is prohibited.
Expand Down Expand Up @@ -2684,12 +2694,12 @@ def rate(self) -> DualTypes | None:

def npv(
self,
curve: Curve | NoInput = NoInput(0),
disc_curve: Curve | NoInput = NoInput(0),
fx: float | FXRates | FXForwards | NoInput = NoInput(0),
base: str | NoInput = NoInput(0),
curve: CurveOption_ = NoInput(0),
disc_curve: Curve_ = NoInput(0),
fx: FX_ = NoInput(0),
base: str_ = NoInput(0),
local: bool = False,
) -> DualTypes | dict[str, DualTypes]:
) -> NPV:
"""
Return the NPV of the *Cashflow*.
See
Expand All @@ -2701,10 +2711,10 @@ def npv(

def cashflows(
self,
curve: Curve | NoInput = NoInput(0),
disc_curve: Curve | NoInput = NoInput(0),
fx: float | FXRates | FXForwards | NoInput = NoInput(0),
base: str | NoInput = NoInput(0),
curve: CurveOption_ = NoInput(0),
disc_curve: Curve_ = NoInput(0),
fx: FX_ = NoInput(0),
base: str_ = NoInput(0),
) -> dict[str, Any]:
"""
Return the cashflows of the *Cashflow*.
Expand Down
Loading