Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add node to CallType #2547

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions slither/core/cfg/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.utils.type_helpers import (
InternalCallType,
SolidityCallType,
HighLevelCallType,
LibraryCallType,
LowLevelCallType,
Expand Down Expand Up @@ -153,8 +154,8 @@ def __init__(
self._ssa_vars_written: List["SlithIRVariable"] = []
self._ssa_vars_read: List["SlithIRVariable"] = []

self._internal_calls: List[Union["Function", "SolidityFunction"]] = []
self._solidity_calls: List[SolidityFunction] = []
self._internal_calls: List["InternalCallType"] = []
self._solidity_calls: List["SolidityCallType"] = []
self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls
self._library_calls: List["LibraryCallType"] = []
self._low_level_calls: List["LowLevelCallType"] = []
Expand Down Expand Up @@ -226,8 +227,9 @@ def type(self, new_type: NodeType) -> None:
@property
def will_return(self) -> bool:
if not self.sons and self.type != NodeType.THROW:
if SolidityFunction("revert()") not in self.solidity_calls:
if SolidityFunction("revert(string)") not in self.solidity_calls:
solidity_calls = [c for c, _ in self.solidity_calls]
if SolidityFunction("revert()") not in solidity_calls:
if SolidityFunction("revert(string)") not in solidity_calls:
return True
return False

Expand Down Expand Up @@ -375,14 +377,14 @@ def variables_written_as_expression(self, exprs: List[Expression]) -> None:
@property
def internal_calls(self) -> List["InternalCallType"]:
"""
list(Function or SolidityFunction): List of internal/soldiity function calls
list(Function or SolidityFunction): List of internal/solidity function calls
"""
return list(self._internal_calls)

@property
def solidity_calls(self) -> List[SolidityFunction]:
def solidity_calls(self) -> List["SolidityCallType"]:
"""
list(SolidityFunction): List of Soldity calls
list(SolidityFunction): List of Solidity calls
"""
return list(self._solidity_calls)

Expand Down Expand Up @@ -530,7 +532,7 @@ def contains_require_or_assert(self) -> bool:
"""
return any(
c.name in ["require(bool)", "require(bool,string)", "assert(bool)"]
for c in self.internal_calls
for c, _ in self.internal_calls
)

def contains_if(self, include_loop: bool = True) -> bool:
Expand Down Expand Up @@ -894,11 +896,11 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements
self._vars_written.append(var)

if isinstance(ir, InternalCall):
self._internal_calls.append(ir.function)
self._internal_calls.append((ir.function, ir.node))
if isinstance(ir, SolidityCall):
# TODO: consider removing dependancy of solidity_call to internal_call
self._solidity_calls.append(ir.function)
self._internal_calls.append(ir.function)
self._solidity_calls.append((ir.function, ir.node))
self._internal_calls.append((ir.function, ir.node))
if (
isinstance(ir, SolidityCall)
and ir.function == SolidityFunction("sstore(uint256,uint256)")
Expand All @@ -916,22 +918,24 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements
self._vars_read.append(ir.arguments[0])
if isinstance(ir, LowLevelCall):
assert isinstance(ir.destination, (Variable, SolidityVariable))
self._low_level_calls.append((ir.destination, str(ir.function_name.value)))
self._low_level_calls.append((ir.destination, str(ir.function_name.value), ir.node))
elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall):
# Todo investigate this if condition
# It does seem right to compare against a contract
# This might need a refactoring
if isinstance(ir.destination.type, Contract):
self._high_level_calls.append((ir.destination.type, ir.function))
self._high_level_calls.append((ir.destination.type, ir.function, ir.node))
elif ir.destination == SolidityVariable("this"):
func = self.function
# Can't use this in a top level function
assert isinstance(func, FunctionContract)
self._high_level_calls.append((func.contract, ir.function))
self._high_level_calls.append((func.contract, ir.function, ir.node))
else:
try:
# Todo this part needs more tests and documentation
self._high_level_calls.append((ir.destination.type.type, ir.function))
self._high_level_calls.append(
(ir.destination.type.type, ir.function, ir.node)
)
except AttributeError as error:
# pylint: disable=raise-missing-from
raise SlitherException(
Expand All @@ -940,8 +944,8 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements
elif isinstance(ir, LibraryCall):
assert isinstance(ir.destination, Contract)
assert isinstance(ir.function, Function)
self._high_level_calls.append((ir.destination, ir.function))
self._library_calls.append((ir.destination, ir.function))
self._high_level_calls.append((ir.destination, ir.function, ir.node))
self._library_calls.append((ir.destination, ir.function, ir.node))

self._vars_read = list(set(self._vars_read))
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
Expand Down
8 changes: 6 additions & 2 deletions slither/core/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,15 +1023,19 @@ def get_functions_overridden_by(self, function: "Function") -> List["Function"]:
###################################################################################

@property
def all_functions_called(self) -> List["InternalCallType"]:
def all_functions_called(self) -> List[Function]:
"""
list(Function): List of functions reachable from the contract
Includes super, and private/internal functions not shadowed
"""
if self._all_functions_called is None:
all_functions = [f for f in self.functions + self.modifiers if not f.is_shadowed] # type: ignore
all_callss = [f.all_internal_calls() for f in all_functions] + [list(all_functions)]
all_calls = [item for sublist in all_callss for item in sublist]
all_calls = [
item[0] if isinstance(item, Tuple) else item
for sublist in all_callss
for item in sublist
]
all_calls = list(set(all_calls))

all_constructors = [c.constructor for c in self.inheritance if c.constructor]
Expand Down
15 changes: 8 additions & 7 deletions slither/core/declarations/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
if TYPE_CHECKING:
from slither.utils.type_helpers import (
InternalCallType,
SolidityCallType,
LowLevelCallType,
HighLevelCallType,
LibraryCallType,
Expand Down Expand Up @@ -150,7 +151,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None:
self._solidity_vars_read: List["SolidityVariable"] = []
self._state_vars_written: List["StateVariable"] = []
self._internal_calls: List["InternalCallType"] = []
self._solidity_calls: List["SolidityFunction"] = []
self._solidity_calls: List["SolidityCallType"] = []
self._low_level_calls: List["LowLevelCallType"] = []
self._high_level_calls: List["HighLevelCallType"] = []
self._library_calls: List["LibraryCallType"] = []
Expand All @@ -173,7 +174,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None:
self._all_high_level_calls: Optional[List["HighLevelCallType"]] = None
self._all_library_calls: Optional[List["LibraryCallType"]] = None
self._all_low_level_calls: Optional[List["LowLevelCallType"]] = None
self._all_solidity_calls: Optional[List["SolidityFunction"]] = None
self._all_solidity_calls: Optional[List["SolidityCallType"]] = None
self._all_variables_read: Optional[List["Variable"]] = None
self._all_variables_written: Optional[List["Variable"]] = None
self._all_state_variables_read: Optional[List["StateVariable"]] = None
Expand Down Expand Up @@ -864,7 +865,7 @@ def internal_calls(self) -> List["InternalCallType"]:
return list(self._internal_calls)

@property
def solidity_calls(self) -> List[SolidityFunction]:
def solidity_calls(self) -> List["SolidityCallType"]:
"""
list(SolidityFunction): List of Soldity calls
"""
Expand Down Expand Up @@ -1121,10 +1122,10 @@ def _explore_functions(self, f_new_values: Callable[["Function"], List]) -> List
values = f_new_values(self)
explored = [self]
to_explore = [
c for c in self.internal_calls if isinstance(c, Function) and c not in explored
c for c, _ in self.internal_calls if isinstance(c, Function) and c not in explored
]
to_explore += [
c for (_, c) in self.library_calls if isinstance(c, Function) and c not in explored
c for _, c, _ in self.library_calls if isinstance(c, Function) and c not in explored
]
to_explore += [m for m in self.modifiers if m not in explored]

Expand All @@ -1139,12 +1140,12 @@ def _explore_functions(self, f_new_values: Callable[["Function"], List]) -> List

to_explore += [
c
for c in f.internal_calls
for c, _ in f.internal_calls
if isinstance(c, Function) and c not in explored and c not in to_explore
]
to_explore += [
c
for (_, c) in f.library_calls
for _, c, _ in f.library_calls
if isinstance(c, Function) and c not in explored and c not in to_explore
]
to_explore += [m for m in f.modifiers if m not in explored and m not in to_explore]
Expand Down
2 changes: 1 addition & 1 deletion slither/core/declarations/function_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_summary(
[str(x) for x in self.modifiers],
[str(x) for x in self.state_variables_read + self.solidity_variables_read],
[str(x) for x in self.state_variables_written],
[str(x) for x in self.internal_calls],
[str(x) for x, _ in self.internal_calls],
[str(x) for x in self.external_calls_as_expressions],
compute_cyclomatic_complexity(self),
)
Expand Down
2 changes: 1 addition & 1 deletion slither/core/declarations/function_top_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_summary(
[str(x) for x in self.modifiers],
[str(x) for x in self.state_variables_read + self.solidity_variables_read],
[str(x) for x in self.state_variables_written],
[str(x) for x in self.internal_calls],
[str(x) for x, _ in self.internal_calls],
[str(x) for x in self.external_calls_as_expressions],
)

Expand Down
2 changes: 1 addition & 1 deletion slither/detectors/assembly/incorrect_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _detect(self) -> List[Output]:

for node in f.nodes:
if node.sons:
for function_called in node.internal_calls:
for function_called, _ in node.internal_calls:
if isinstance(function_called, Function):
found = _assembly_node(function_called)
if found:
Expand Down
2 changes: 1 addition & 1 deletion slither/detectors/attributes/locked_ether.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def do_no_send_ether(contract: Contract) -> bool:
explored += to_explore
to_explore = []
for function in functions:
calls = [c.name for c in function.internal_calls]
calls = [c.name for (c, _) in function.internal_calls]
if "suicide(address)" in calls or "selfdestruct(address)" in calls:
return False
for node in function.nodes:
Expand Down
6 changes: 4 additions & 2 deletions slither/detectors/functions/dead_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ def _detect(self) -> List[Output]:
all_functionss_called = [
f.all_internal_calls() for f in contract.functions_entry_points
]
all_functions_called = [item for sublist in all_functionss_called for item in sublist]
all_functions_called = [
item[0] for sublist in all_functionss_called for item in sublist
]
functions_used |= {
f.canonical_name for f in all_functions_called if isinstance(f, Function)
}
all_libss_called = [f.all_library_calls() for f in contract.functions_entry_points]
all_libs_called: List[Tuple[Contract, Function]] = [
item for sublist in all_libss_called for item in sublist
item[0] for sublist in all_libss_called for item in sublist
]
functions_used |= {
lib[1].canonical_name for lib in all_libs_called if isinstance(lib, tuple)
Expand Down
2 changes: 1 addition & 1 deletion slither/detectors/functions/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

def is_revert(node: Node) -> bool:
return node.type == NodeType.THROW or any(
c.name in ["revert()", "revert(string"] for c in node.internal_calls
c.name in ["revert()", "revert(string"] for c, _ in node.internal_calls
)


Expand Down
2 changes: 1 addition & 1 deletion slither/detectors/functions/out_of_order_retryable.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _detect_multiple_tickets(

# include ops from internal function calls
internal_ops = []
for internal_call in node.internal_calls:
for internal_call, _ in node.internal_calls:
if isinstance(internal_call, Function):
internal_ops += internal_call.all_slithir_operations()

Expand Down
2 changes: 1 addition & 1 deletion slither/detectors/functions/protected_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _analyze_function(self, function: Function, contract: Contract) -> List[Outp
if not function_protection:
self.logger.error(f"{function_sig} not found")
continue
if function_protection not in function.all_internal_calls():
if function_protection not in [f for f, _ in function.all_internal_calls()]:
info: DETECTOR_INFO = [
function,
" should have ",
Expand Down
2 changes: 1 addition & 1 deletion slither/detectors/functions/suicidal.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def detect_suicidal_func(func: FunctionContract) -> bool:
if func.visibility not in ["public", "external"]:
return False

calls = [c.name for c in func.all_internal_calls()]
calls = [c.name for c, _ in func.all_internal_calls()]
if not ("suicide(address)" in calls or "selfdestruct(address)" in calls):
return False

Expand Down
2 changes: 1 addition & 1 deletion slither/detectors/reentrancy/reentrancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def analyze_node(self, node: Node, detector: "Reentrancy") -> bool:
)
slithir_operations = []
# Add the state variables written in internal calls
for internal_call in node.internal_calls:
for internal_call, _ in node.internal_calls:
# Filter to Function, as internal_call can be a solidity call
if isinstance(internal_call, Function):
for internal_node in internal_call.all_nodes():
Expand Down
2 changes: 1 addition & 1 deletion slither/detectors/statements/assert_state_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def detect_assert_state_change(
for function in contract.functions_declared + list(contract.modifiers_declared):
for node in function.nodes:
# Detect assert() calls
if any(c.name == "assert(bool)" for c in node.internal_calls) and (
if any(c.name == "assert(bool)" for c, _ in node.internal_calls) and (
# Detect direct changes to state
node.state_variables_written
or
Expand Down
2 changes: 1 addition & 1 deletion slither/detectors/statements/divide_before_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def is_assert(node: Node) -> bool:
# Old Solidity code where using an internal 'assert(bool)' function
# While we dont check that this function is correct, we assume it is
# To avoid too many FP
if "assert(bool)" in [c.full_name for c in node.internal_calls]:
if "assert(bool)" in [c.full_name for c, _ in node.internal_calls]:
return True
return False

Expand Down
2 changes: 1 addition & 1 deletion slither/detectors/statements/unprotected_upgradeable.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _has_initializing_protection(functions: List[Function]) -> bool:
for m in f.modifiers:
if m.name == "initializer":
return True
for ifc in f.all_internal_calls():
for ifc, _ in f.all_internal_calls():
if ifc.name == "_disableInitializers":
return True

Expand Down
9 changes: 5 additions & 4 deletions slither/printers/call/call_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
The output is a dot file named filename.dot
"""
from collections import defaultdict
from typing import Optional, Union, Dict, Set, Tuple, Sequence
from typing import Optional, Union, Dict, Set, Sequence

from slither.core.declarations import Contract, FunctionContract
from slither.core.declarations.function import Function
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.variables.variable import Variable
from slither.printers.abstract_printer import AbstractPrinter
from slither.utils.output import Output
from slither.utils.type_helpers import HighLevelCallType


def _contract_subgraph(contract: Contract) -> str:
Expand Down Expand Up @@ -112,12 +113,12 @@ def _render_solidity_calls(solidity_functions: Set[str], solidity_calls: Set[str
def _process_external_call(
contract: Contract,
function: Function,
external_call: Tuple[Contract, Union[Function, Variable]],
external_call: HighLevelCallType,
contract_functions: Dict[Contract, Set[str]],
external_calls: Set[str],
all_contracts: Set[Contract],
) -> None:
external_contract, external_function = external_call
external_contract, external_function, _ = external_call

if not external_contract in all_contracts:
return
Expand Down Expand Up @@ -154,7 +155,7 @@ def _process_function(
_node(_function_node(contract, function), function.name),
)

for internal_call in function.internal_calls:
for internal_call, _ in function.internal_calls:
_process_internal_call(
contract,
function,
Expand Down
2 changes: 1 addition & 1 deletion slither/printers/functions/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PrinterWrittenVariablesAndAuthorization(AbstractPrinter):
@staticmethod
def get_msg_sender_checks(function: Function) -> List[str]:
all_functions = (
[f for f in function.all_internal_calls() if isinstance(f, Function)]
[f for f, _ in function.all_internal_calls() if isinstance(f, Function)]
+ [function]
+ [m for m in function.modifiers if isinstance(m, Function)]
)
Expand Down
4 changes: 2 additions & 2 deletions slither/printers/summary/modifier_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def output(self, _filename):
table = MyPrettyTable(["Function", "Modifiers"])
for function in contract.functions:
modifiers = function.modifiers
for call in function.all_internal_calls():
for call, _ in function.all_internal_calls():
if isinstance(call, Function):
modifiers += call.modifiers
for (_, call) in function.all_library_calls():
for _, call, _ in function.all_library_calls():
if isinstance(call, Function):
modifiers += call.modifiers
table.add_row([function.name, sorted([m.name for m in set(modifiers)])])
Expand Down
2 changes: 1 addition & 1 deletion slither/printers/summary/when_not_paused.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def _use_modifier(function: Function, modifier_name: str = "whenNotPaused") -> bool:

for internal_call in function.all_internal_calls():
for internal_call, _ in function.all_internal_calls():
if isinstance(internal_call, SolidityFunction):
continue
if any(modifier.name == modifier_name for modifier in function.modifiers):
Expand Down
Loading
Loading