Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: load("@rules_cuda_redist_json//:redist.bzl", "rules_cuda_components") #286

Open
wants to merge 1 commit into
base: cloudhan/hermetic-ctk-2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion cuda/extensions.bzl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Entry point for extensions used by bzlmod."""

load("//cuda/private:compat.bzl", "components_mapping_compat")
load("//cuda/private:repositories.bzl", "cuda_component", "local_cuda")
load("//cuda/private:repositories.bzl", "cuda_component", "cuda_redist_json", "local_cuda")

cuda_component_tag = tag_class(attrs = {
"name": attr.string(mandatory = True, doc = "Repo name for the deliverable cuda_component"),
Expand Down Expand Up @@ -33,6 +33,30 @@ cuda_component_tag = tag_class(attrs = {
),
})

cuda_redist_json_tag = tag_class(attrs = {
"name": attr.string(mandatory = True, doc = "Repo name for the cuda_redist_json"),
"components": attr.string_list(mandatory = True, doc = "components to be used"),
"integrity": attr.string(
doc = "Expected checksum in Subresource Integrity format of the file downloaded. " +
"This must match the checksum of the file downloaded.",
),
"sha256": attr.string(
doc = "The expected SHA-256 of the file downloaded. " +
"This must match the SHA-256 of the file downloaded.",
),
"urls": attr.string_list(
doc = "A list of URLs to a file that will be made available to Bazel. " +
"Each entry must be a file, http or https URL. Redirections are followed. " +
"Authentication is not supported. " +
"URLs are tried in order until one succeeds, so you should list local mirrors first. " +
"If all downloads fail, the rule will fail.",
),
"version": attr.string(
doc = "Generate a URL by using the specified version." +
"This URL will be tried after all URLs specified in the `urls` attribute.",
),
})

cuda_toolkit_tag = tag_class(attrs = {
"name": attr.string(mandatory = True, doc = "Name for the toolchain repository", default = "local_cuda"),
"toolkit_path": attr.string(
Expand Down Expand Up @@ -70,17 +94,23 @@ def _impl(module_ctx):
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
root, rules_cuda = _find_modules(module_ctx)
components = None
redist_jsons = None
toolkits = None
if root.tags.toolkit:
components = root.tags.component
redist_jsons = root.tags.redist_json
toolkits = root.tags.toolkit
else:
components = rules_cuda.tags.component
redist_jsons = rules_cuda.tags.redist_json
toolkits = rules_cuda.tags.toolkit

for component in components:
cuda_component(**_module_tag_to_dict(component))

for redist_json in redist_jsons:
cuda_redist_json(**_module_tag_to_dict(redist_json))

registrations = {}
for toolkit in toolkits:
if toolkit.name in registrations.keys():
Expand All @@ -97,6 +127,7 @@ toolchain = module_extension(
implementation = _impl,
tag_classes = {
"component": cuda_component_tag,
"redist_json": cuda_redist_json_tag,
"toolkit": cuda_toolkit_tag,
},
)
66 changes: 66 additions & 0 deletions cuda/private/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,72 @@ def default_components_mapping(components):
"""
return {c: "@local_cuda_" + c for c in components}

def _cuda_redist_json_impl(repository_ctx):
the_url = None # the url that successfully fetch redist json, we then use it to fetch deliverables
urls = [u for u in repository_ctx.attr.urls]

ver = repository_ctx.attr.version
if ver:
urls.append("https://developer.download.nvidia.com/compute/cuda/redist/redistrib_{}.json".format(ver))

if len(urls) == 0:
fail("`urls` or `version` must be specified.")

for url in urls:
ret = repository_ctx.download(
output = "redist.json",
integrity = repository_ctx.attr.integrity,
sha256 = repository_ctx.attr.sha256,
url = url,
)
if ret.success:
the_url = url
break

if the_url == None:
fail("Failed to retrieve the redist json file.")

# convert redist.json to list of spec (list of dicts with cuda_components attrs)
specs = []
redist = json.decode(repository_ctx.read("redist.json"))
for c in repository_ctx.attr.components:
c_full = FULL_COMPONENT_NAME[c]
os = None
if _is_linux(repository_ctx):
os = "linux"
elif _is_windows(repository_ctx):
os = "windows"

arch = "x86_64" # TODO: support cross compiling
platform = "{os}-{arch}".format(os = os, arch = arch)

payload = redist[c_full][platform]
payload_relative_path = payload["relative_path"]
payload_url = the_url.rsplit("/", 1)[0] + "/" + payload_relative_path
archive_name = payload_relative_path.rsplit("/", 1)[1].split("-archive.")[0] + "-archive"

specs.append({
"component_name": c,
"urls": [payload_url],
"sha256": payload["sha256"],
"strip_prefix": archive_name,
"version": redist[c_full]["version"],
})

template_helper.generate_redist_bzl(repository_ctx, specs)
repository_ctx.symlink(Label("//cuda/private:templates/BUILD.redist_json"), "BUILD")

cuda_redist_json = repository_rule(
implementation = _cuda_redist_json_impl,
attrs = {
"components": attr.string_list(mandatory = True),
"integrity": attr.string(mandatory = False),
"sha256": attr.string(mandatory = False),
"urls": attr.string_list(mandatory = False),
"version": attr.string(mandatory = False),
},
)

def rules_cuda_dependencies():
"""Populate the dependencies for rules_cuda. This will setup other bazel rules as workspace dependencies"""
maybe(
Expand Down
44 changes: 44 additions & 0 deletions cuda/private/template_helper.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,49 @@ def _generate_defs_bzl(repository_ctx, is_local_ctk):
}
repository_ctx.template("defs.bzl", tpl_label, substitutions = substitutions, executable = False)

def _generate_redist_bzl(repository_ctx, component_specs):
"""Generate `@rules_cuda_redist_json//:redist.bzl`

Args:
repository_ctx: repository_ctx
component_specs: list of dict, dict keys are component_name, urls, sha256, strip_prefix and version
"""

rules_cuda_components_body = []
mapping = {}

component_tpl = """cuda_component(
name = "{repo_name}",
component_name = "{component_name}",
sha256 = {sha256},
strip_prefix = {strip_prefix},
urls = {urls},
)"""

for spec in component_specs:
repo_name = "local_cuda_" + spec["component_name"]
version = spec.get("version", None)
if version != None:
repo_name = repo_name + "_v" + version

rules_cuda_components_body.append(
component_tpl.format(
repo_name = repo_name,
component_name = spec["component_name"],
sha256 = repr(spec["sha256"]),
strip_prefix = repr(spec["strip_prefix"]),
urls = repr(spec["urls"]),
),
)
mapping[spec["component_name"]] = "@" + repo_name

tpl_label = Label("//cuda/private:templates/redist.bzl.tpl")
substitutions = {
"%{rules_cuda_components_body}": "\n\n ".join(rules_cuda_components_body),
"%{components_mapping}": repr(mapping),
}
repository_ctx.template("redist.bzl", tpl_label, substitutions = substitutions, executable = False)

def _generate_toolchain_build(repository_ctx, cuda):
tpl_label = Label(
"//cuda/private:templates/BUILD.local_toolchain_" +
Expand Down Expand Up @@ -127,6 +170,7 @@ def _generate_toolchain_clang_build(repository_ctx, cuda, clang_path):
template_helper = struct(
generate_build = _generate_build,
generate_defs_bzl = _generate_defs_bzl,
generate_redist_bzl = _generate_redist_bzl,
generate_toolchain_build = _generate_toolchain_build,
generate_toolchain_clang_build = _generate_toolchain_clang_build,
)
18 changes: 18 additions & 0 deletions cuda/private/templates/BUILD.redist_json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package(
default_visibility = ["//visibility:public"],
)

filegroup(
name = "redist_bzl",
srcs = [":redist.bzl"],
)

filegroup(
name = "redist_json",
srcs = [":redist.json"],
)

exports_files([
"redist.bzl",
"redist.json",
])
14 changes: 14 additions & 0 deletions cuda/private/templates/redist.bzl.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
load("@rules_cuda//cuda:repositories.bzl", "cuda_component", "rules_cuda_toolchains")

def rules_cuda_components():
# See template_helper.generate_redist_bzl(...) for body generation logic
%{rules_cuda_components_body}

return %{components_mapping}

def rules_cuda_components_and_toolchains(register_toolchains = False):
components_mapping = rules_cuda_components()
rules_cuda_toolchains(
components_mapping= components_mapping,
register_toolchains = register_toolchains,
)
2 changes: 2 additions & 0 deletions cuda/repositories.bzl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load(
"//cuda/private:repositories.bzl",
_cuda_component = "cuda_component",
_cuda_redist_json = "cuda_redist_json",
_default_components_mapping = "default_components_mapping",
_local_cuda = "local_cuda",
_rules_cuda_dependencies = "rules_cuda_dependencies",
Expand All @@ -10,6 +11,7 @@ load("//cuda/private:toolchain.bzl", _register_detected_cuda_toolchains = "regis

# rules
cuda_component = _cuda_component
cuda_redist_json = _cuda_redist_json
local_cuda = _local_cuda

# macros
Expand Down