Skip to content

Commit

Permalink
Fix pyre-fixmes in captum/attr/_utils/input_layer_wrapper.py (#1453)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1453

Fix Pyre fixmes in the input layer wrapper python file.

Reviewed By: craymichael

Differential Revision: D67032036

fbshipit-source-id: d93a28588a630bfda408f3a61f8b96553064706d
  • Loading branch information
styusuf authored and facebook-github-bot committed Dec 11, 2024
1 parent ad77e79 commit 92d82df
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions captum/attr/_utils/input_layer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
# pyre-strict

import inspect
from typing import Any
from typing import List

import torch.nn as nn
from torch import Tensor


class InputIdentity(nn.Module):
Expand All @@ -21,9 +22,7 @@ def __init__(self, input_name: str) -> None:
super().__init__()
self.input_name = input_name

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
return x


Expand Down Expand Up @@ -64,20 +63,19 @@ def __init__(self, module_to_wrap: nn.Module) -> None:
self.module = module_to_wrap

# ignore self
# pyre-fixme[4]: Attribute must be annotated.
self.arg_name_list = inspect.getfullargspec(module_to_wrap.forward).args[1:]
self.arg_name_list: List[str] = inspect.getfullargspec(
module_to_wrap.forward
).args[1:]
self.input_maps = nn.ModuleDict(
{arg_name: InputIdentity(arg_name) for arg_name in self.arg_name_list}
)

# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter must be annotated.
def forward(self, *args, **kwargs) -> Any:
args = list(args)
for idx, (arg_name, arg) in enumerate(zip(self.arg_name_list, args)):
args[idx] = self.input_maps[arg_name](arg)
def forward(self, *args: object, **kwargs: object) -> object:
args_list = list(args)
for idx, (arg_name, arg) in enumerate(zip(self.arg_name_list, args_list)):
args_list[idx] = self.input_maps[arg_name](arg)

for arg_name in kwargs.keys():
kwargs[arg_name] = self.input_maps[arg_name](kwargs[arg_name])

return self.module(*tuple(args), **kwargs)
return self.module(*tuple(args_list), **kwargs)

0 comments on commit 92d82df

Please sign in to comment.