Skip to content

Commit

Permalink
Account for aliases in alpha spec check (#37)
Browse files Browse the repository at this point in the history
* Account for aliases in alpha spec check

* Use AnchorPreservingLoader

* Add integration test

* Fix no-anchor nodes

* Review feedback
  • Loading branch information
KyleFromNVIDIA authored May 30, 2024
1 parent c31993c commit 5c86048
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 65 deletions.
139 changes: 95 additions & 44 deletions src/rapids_pre_commit_hooks/alpha_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def is_rapids_cuda_suffixed_package(name):
)


def check_package_spec(linter, args, node):
def check_package_spec(linter, args, anchors, used_anchors, node):
@total_ordering
class SpecPriority:
def __init__(self, spec):
Expand Down Expand Up @@ -111,42 +111,53 @@ def create_specifier_string(specifiers):
if req.name in RAPIDS_ALPHA_SPEC_PACKAGES or is_rapids_cuda_suffixed_package(
req.name
):
has_alpha_spec = any(str(s) == ALPHA_SPECIFIER for s in req.specifier)
if args.mode == "development" and not has_alpha_spec:
linter.add_warning(
(node.start_mark.index, node.end_mark.index),
f"add alpha spec for RAPIDS package {req.name}",
).add_replacement(
(node.start_mark.index, node.end_mark.index),
str(
req.name
+ create_specifier_string(
{str(s) for s in req.specifier} | {ALPHA_SPECIFIER}
)
),
)
elif args.mode == "release" and has_alpha_spec:
linter.add_warning(
(node.start_mark.index, node.end_mark.index),
f"remove alpha spec for RAPIDS package {req.name}",
).add_replacement(
(node.start_mark.index, node.end_mark.index),
str(
req.name
+ create_specifier_string(
{str(s) for s in req.specifier} - {ALPHA_SPECIFIER}
)
),
)


def check_packages(linter, args, node):
for key, value in anchors.items():
if value == node:
anchor = key
break
else:
anchor = None
if anchor not in used_anchors:
if anchor is not None:
used_anchors.add(anchor)
has_alpha_spec = any(str(s) == ALPHA_SPECIFIER for s in req.specifier)
if args.mode == "development" and not has_alpha_spec:
linter.add_warning(
(node.start_mark.index, node.end_mark.index),
f"add alpha spec for RAPIDS package {req.name}",
).add_replacement(
(node.start_mark.index, node.end_mark.index),
str(
(f"&{anchor} " if anchor else "")
+ req.name
+ create_specifier_string(
{str(s) for s in req.specifier} | {ALPHA_SPECIFIER},
)
),
)
elif args.mode == "release" and has_alpha_spec:
linter.add_warning(
(node.start_mark.index, node.end_mark.index),
f"remove alpha spec for RAPIDS package {req.name}",
).add_replacement(
(node.start_mark.index, node.end_mark.index),
str(
(f"&{anchor} " if anchor else "")
+ req.name
+ create_specifier_string(
{str(s) for s in req.specifier} - {ALPHA_SPECIFIER},
)
),
)


def check_packages(linter, args, anchors, used_anchors, node):
if node_has_type(node, "seq"):
for package_spec in node.value:
check_package_spec(linter, args, package_spec)
check_package_spec(linter, args, anchors, used_anchors, package_spec)


def check_common(linter, args, node):
def check_common(linter, args, anchors, used_anchors, node):
if node_has_type(node, "seq"):
for dependency_set in node.value:
if node_has_type(dependency_set, "map"):
Expand All @@ -155,10 +166,12 @@ def check_common(linter, args, node):
node_has_type(dependency_set_key, "str")
and dependency_set_key.value == "packages"
):
check_packages(linter, args, dependency_set_value)
check_packages(
linter, args, anchors, used_anchors, dependency_set_value
)


def check_matrices(linter, args, node):
def check_matrices(linter, args, anchors, used_anchors, node):
if node_has_type(node, "seq"):
for item in node.value:
if node_has_type(item, "map"):
Expand All @@ -167,10 +180,12 @@ def check_matrices(linter, args, node):
node_has_type(matrix_key, "str")
and matrix_key.value == "packages"
):
check_packages(linter, args, matrix_value)
check_packages(
linter, args, anchors, used_anchors, matrix_value
)


def check_specific(linter, args, node):
def check_specific(linter, args, anchors, used_anchors, node):
if node_has_type(node, "seq"):
for matrix_matcher in node.value:
if node_has_type(matrix_matcher, "map"):
Expand All @@ -179,30 +194,66 @@ def check_specific(linter, args, node):
node_has_type(matrix_matcher_key, "str")
and matrix_matcher_key.value == "matrices"
):
check_matrices(linter, args, matrix_matcher_value)
check_matrices(
linter, args, anchors, used_anchors, matrix_matcher_value
)


def check_dependencies(linter, args, node):
def check_dependencies(linter, args, anchors, used_anchors, node):
if node_has_type(node, "map"):
for _, dependencies_value in node.value:
if node_has_type(dependencies_value, "map"):
for dependency_key, dependency_value in dependencies_value.value:
if node_has_type(dependency_key, "str"):
if dependency_key.value == "common":
check_common(linter, args, dependency_value)
check_common(
linter, args, anchors, used_anchors, dependency_value
)
elif dependency_key.value == "specific":
check_specific(linter, args, dependency_value)
check_specific(
linter, args, anchors, used_anchors, dependency_value
)


def check_root(linter, args, node):
def check_root(linter, args, anchors, used_anchors, node):
if node_has_type(node, "map"):
for root_key, root_value in node.value:
if node_has_type(root_key, "str") and root_key.value == "dependencies":
check_dependencies(linter, args, root_value)
check_dependencies(linter, args, anchors, used_anchors, root_value)


class AnchorPreservingLoader(yaml.SafeLoader):
"""A SafeLoader that preserves the anchors for later reference. The anchors can
be found in the document_anchors member, which is a list of dictionaries, one
dictionary for each parsed document.
"""

def __init__(self, stream):
super().__init__(stream)
self.document_anchors = []

def compose_document(self):
# Drop the DOCUMENT-START event.
self.get_event()

# Compose the root node.
node = self.compose_node(None, None)

# Drop the DOCUMENT-END event.
self.get_event()

self.document_anchors.append(self.anchors)
self.anchors = {}
return node


def check_alpha_spec(linter, args):
check_root(linter, args, yaml.compose(linter.content))
loader = AnchorPreservingLoader(linter.content)
try:
root = loader.get_single_node()
finally:
loader.dispose()
check_root(linter, args, loader.document_anchors[0], set(), root)


def main():
Expand Down
Loading

0 comments on commit 5c86048

Please sign in to comment.