diff --git a/.gitignore b/.gitignore index df191195..0ec272f6 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ vulture.egg-info/ .pytest_cache/ .tox/ .venv/ +.vscode/ diff --git a/tests/test_scavenging.py b/tests/test_scavenging.py index 27154cd6..6a287547 100644 --- a/tests/test_scavenging.py +++ b/tests/test_scavenging.py @@ -899,3 +899,38 @@ class Color(Enum): check(v.unused_classes, []) check(v.unused_vars, ["BLUE"]) + + +def test_enum_list(v): + v.scan( + """\ +import enum +class E(enum.Enum): + A = 1 + B = 2 + +print(list(E)) +""" + ) + + check(v.defined_classes, ["E"]) + check(v.defined_vars, ["A", "B"]) + check(v.unused_vars, []) + + +def test_enum_for(v): + v.scan( + """\ +import enum +class E(enum.Enum): + A = 1 + B = 2 + +for e in E: + print(e) +""" + ) + + check(v.defined_classes, ["E"]) + check(v.defined_vars, ["A", "B", "e"]) + check(v.unused_vars, []) diff --git a/vulture/core.py b/vulture/core.py index 16c71948..145fa6c6 100644 --- a/vulture/core.py +++ b/vulture/core.py @@ -219,6 +219,8 @@ def get_list(typ): self.code = [] self.found_dead_code_or_error = False + self.enum_class_vars = dict() # stores variables defined in enum classes + def scan(self, code, filename=""): filename = Path(filename) self.code = code.splitlines() @@ -551,6 +553,18 @@ def visit_Call(self, node): ): self._handle_new_format_string(node.func.value.s) + # handle enum.Enum members + iter_functions = ["list", "tuple", "set"] + if ( + isinstance(node.func, ast.Name) + and node.func.id in iter_functions + and len(node.args) > 0 + and isinstance(node.args[0], ast.Name) + ): + arg = node.args[0].id + if arg in self.enum_class_vars: + self.used_names.update(self.enum_class_vars[arg]) + def _handle_new_format_string(self, s): def is_identifier(name): return bool(re.match(r"[a-zA-Z_][a-zA-Z0-9_]*", name)) @@ -594,6 +608,28 @@ def visit_ClassDef(self, node): self._define( self.defined_classes, node.name, node, ignore=_ignore_class ) + # if subclasses enum add class variables to enum_class_vars + if self._subclassesEnum(node): + newKey = node.name + classVariables = [] + for stmt in node.body: + if isinstance(stmt, ast.Assign): + for target in stmt.targets: + classVariables.append(target.id) + self.enum_class_vars[newKey] = classVariables + + def _subclassesEnum(self, node): + ''' + Checks if a class has Enum as a superclass + ''' + for base in node.bases: + if isinstance(base, ast.Name): + if base.id.lower() == "enum": + return True + elif isinstance(base, ast.Attribute): + if base.value.id.lower() == "enum": + return True + return False def visit_FunctionDef(self, node): decorator_names = [ @@ -661,6 +697,14 @@ def visit_Assign(self, node): def visit_While(self, node): self._handle_conditional_node(node, "while") + def visit_For(self, node): + # Handle iterating over Enum + if ( + isinstance(node.iter, ast.Name) + and node.iter.id in self.enum_class_vars + ): + self.used_names.update(self.enum_class_vars[node.iter.id]) + def visit_MatchClass(self, node): for kwd_attr in node.kwd_attrs: self.used_names.add(kwd_attr)