Skip to content

Commit

Permalink
fix: satisfy linter by using optional types
Browse files Browse the repository at this point in the history
  • Loading branch information
fiskrt committed May 23, 2024
1 parent 238125a commit dfed5d8
Show file tree
Hide file tree
Showing 11 changed files with 24 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,7 @@ def validate_output(self, sequences: List[Any]) -> Tuple[List[Any], List[int]]:
if isinstance(self.tokenizer.text_tokenizer, PolymerGraphTokenizer):
# Copolymer models require specific validation
return validate_molecules(
pattern_list=list(zip(*sequences))[0],
pattern_list=list(zip(*sequences))[0], # type: ignore
input_type=MoleculeFormat.copolymer,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def __init__(
targets: Dict[str, float],
property_predictors: Dict[str, PropertyPredictor],
representations: RepresentationsDict,
representation_order: List[str] = None,
representation_order: List[str] | None = None,
scalers: Optional[Dict[str, Scaler]] = None,
weights: Optional[Dict[str, float]] = None,
custom_score_function: Optional[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ def _forward(
bond_index: torch.Tensor,
bond_type: torch.Tensor,
batch: torch.Tensor,
edge_index: torch.Tensor = None,
edge_type: torch.Tensor = None,
edge_length: int = None,
edge_index: torch.Tensor | None = None,
edge_type: torch.Tensor | None = None,
edge_length: int | None = None,
return_edges: bool = False,
extend_order: bool = True,
extend_radius: bool = True,
is_sidechain: bool = None,
is_sidechain: bool | None = None,
) -> Tuple[Any, ...]:
"""Forward pass for edges features.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def set_seed(seed: int = 42) -> None:
"""
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available:
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) # type:ignore


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def set_seed(seed: int = 42) -> None:
"""
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available:
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) # type:ignore


Expand Down
4 changes: 2 additions & 2 deletions src/gt4sd/frameworks/cgcnn/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class GaussianDistance:
Unit: angstrom
"""

def __init__(self, dmin: float, dmax: float, step: float, var: float = None):
def __init__(self, dmin: float, dmax: float, step: float, var: float | None = None):
"""
Args:
dmin: float
Expand Down Expand Up @@ -331,7 +331,7 @@ def __init__(
dmin: int = 0,
step: float = 0.2,
random_seed: int = 123,
atom_initialization: AtomCustomJSONInitializer = None,
atom_initialization: AtomCustomJSONInitializer | None = None,
):
"""
Args:
Expand Down
10 changes: 5 additions & 5 deletions src/gt4sd/frameworks/gflownet/dataloader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class GFlowNetDataset(Dataset):
"""A dataset for gflownet."""

def __init__(
self, h5_file: str = None, target: str = "gap", properties: List[str] = []
self, h5_file: str | None = None, target: str = "gap", properties: List[str] = []
) -> None:

"""Initialize a gflownet dataset.
Expand Down Expand Up @@ -239,8 +239,8 @@ def __init__(
self,
configuration: Dict[str, Any],
dataset: GFlowNetDataset,
reward_model: nn.Module = None,
wrap_model: Callable[[nn.Module], nn.Module] = None,
reward_model: nn.Module | None = None,
wrap_model: Callable[[nn.Module], nn.Module] | None = None,
) -> None:

"""Initialize a generic gflownet task.
Expand Down Expand Up @@ -282,7 +282,7 @@ def load_task_models(self) -> Dict[str, nn.Module]:
Returns:
model: a dictionary with the task models.
"""
pass
raise NotImplementedError()

def sample_conditional_information(self, n: int) -> Dict[str, Any]:
"""Samples conditional information for a minibatch.
Expand All @@ -293,7 +293,7 @@ def sample_conditional_information(self, n: int) -> Dict[str, Any]:
Returns:
cond_info: a dictionary with the sampled conditional information.
"""
pass
raise NotImplementedError()

def cond_info_to_reward(
self, cond_info: Dict[str, Any], flat_reward: FlatRewards
Expand Down
10 changes: 5 additions & 5 deletions src/gt4sd/frameworks/gflownet/envs/graph_building_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ class GraphAction:
def __init__(
self,
action: GraphActionType,
source: int = None,
target: int = None,
source: int | None = None,
target: int | None = None,
value: Any = None,
attr: str = None,
relabel: int = None,
attr: str | None = None,
relabel: int | None = None,
):
"""Initialize a single graph-building action.
Expand Down Expand Up @@ -287,7 +287,7 @@ def count_backward_transitions(self, g: Graph) -> int:


def generate_forward_trajectory(
g: Graph, max_nodes: int = None
g: Graph, max_nodes: int | None = None
) -> List[Tuple[Graph, GraphAction]]:
"""Sample (uniformly) a trajectory that generates g.
Expand Down
2 changes: 1 addition & 1 deletion src/gt4sd/frameworks/gflownet/loss/trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
configuration: Dict[str, Any],
environment: GraphBuildingEnv,
context: GraphBuildingEnvContext,
max_len: int = None,
max_len: int | None = None,
):
"""Initialize trajectory balance algorithm.
Expand Down
4 changes: 2 additions & 2 deletions src/gt4sd/frameworks/gflownet/tests/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def __init__(
self,
configuration: Dict[str, Any],
dataset: GFlowNetDataset,
reward_model: nn.Module = None,
wrap_model: Callable[[nn.Module], nn.Module] = None,
reward_model: nn.Module | None = None,
wrap_model: Callable[[nn.Module], nn.Module] | None = None,
):
"""Initialize QM9 task.
Expand Down
2 changes: 1 addition & 1 deletion src/gt4sd/properties/scores/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class CombinedScorer:
def __init__(
self,
scorer_list: List[Type[Any]],
weights: List[float] = None,
weights: List[float] | None = None,
) -> None:
"""Scoring function which generates a combined score for a SMILES as per the given scoring functions.
Expand Down

0 comments on commit dfed5d8

Please sign in to comment.