From 0d3243abe1692fe814bcaebcb7115b3ad33bb5a4 Mon Sep 17 00:00:00 2001 From: Ben Zickel <35469979+BenZickel@users.noreply.github.com> Date: Tue, 24 Sep 2024 14:38:36 +0300 Subject: [PATCH] Fix EqualizeMessenger type annotations (#3401) * Fix EqualizeMessenger type annotations. * Change keep_dist typing to Optional[bool]. --------- Co-authored-by: Ben Zickel --- pyro/poutine/equalize_messenger.py | 2 +- pyro/poutine/handlers.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pyro/poutine/equalize_messenger.py b/pyro/poutine/equalize_messenger.py index 1bc79a5521..31e524a651 100644 --- a/pyro/poutine/equalize_messenger.py +++ b/pyro/poutine/equalize_messenger.py @@ -68,7 +68,7 @@ def __init__( self, sites: Union[str, List[str]], type: Optional[str] = "sample", - keep_dist: bool = False, + keep_dist: Optional[bool] = False, ) -> None: super().__init__() self.sites = [sites] if isinstance(sites, str) else sites diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py index 343b1a1f4b..bc1ba91de8 100644 --- a/pyro/poutine/handlers.py +++ b/pyro/poutine/handlers.py @@ -306,7 +306,8 @@ def escape( # type: ignore[empty-body] def equalize( sites: Union[str, List[str]], type: Optional[str], -) -> ConditionMessenger: ... + keep_dist: Optional[bool], +) -> EqualizeMessenger: ... @overload @@ -314,6 +315,7 @@ def equalize( fn: Callable[_P, _T], sites: Union[str, List[str]], type: Optional[str], + keep_dist: Optional[bool], ) -> Callable[_P, _T]: ... @@ -322,6 +324,7 @@ def equalize( # type: ignore[empty-body] fn: Callable[_P, _T], sites: Union[str, List[str]], type: Optional[str], + keep_dist: Optional[bool], ) -> Union[EqualizeMessenger, Callable[_P, _T]]: ...