diff --git a/.github/workflows/docker-rocm.yaml b/.github/workflows/docker-rocm.yaml
new file mode 100644
index 00000000..17a3f03f
--- /dev/null
+++ b/.github/workflows/docker-rocm.yaml
@@ -0,0 +1,57 @@
+name: TritonBench Nightly ROCM Docker Build
+on:
+  pull_request:
+    paths:
+      - .github/workflows/docker-rocm.yaml
+      - docker/tritonbench-rocm-nightly.dockerfile
+  workflow_dispatch:
+    inputs:
+      nightly_date:
+        description: "PyTorch nightly version"
+        required: false
+env:
+  CONDA_ENV: "tritonbench"
+  SETUP_SCRIPT: "/workspace/setup_instance.sh"
+
+jobs:
+  build-push-docker:
+    if: ${{ github.repository_owner == 'pytorch-labs' }}
+    runs-on: 32-core-ubuntu
+    environment: docker-s3-upload
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          path: tritonbench
+      - name: Login to GitHub Container Registry
+        if: github.event_name != 'pull_request'
+        uses: docker/login-action@v2
+        with:
+          registry: ghcr.io
+          username: pytorch-labs
+          password: ${{ secrets.TRITONBENCH_ACCESS_TOKEN }}
+      - name: Build TritonBench nightly docker
+        run: |
+          set -x
+          export NIGHTLY_DATE="${{ github.event.inputs.nightly_date }}"
+          cd tritonbench/docker
+          # branch name is github.head_ref when triggered by pull_request
+          # and it is github.ref_name when triggered by workflow_dispatch
+          branch_name=${{ github.head_ref || github.ref_name }}
+          docker build . --build-arg TRITONBENCH_BRANCH="${branch_name}" --build-arg FORCE_DATE="${NIGHTLY_DATE}" \
+            -f tritonbench-rocm-nightly.dockerfile -t ghcr.io/pytorch-labs/tritonbench:rocm-latest
+          # Extract pytorch version from the docker
+          PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:rocm-latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
+          export DOCKER_TAG=$(awk '{match($0, /dev[0-9]+/, arr); print arr[0]}' <<< "${PYTORCH_VERSION}")
+          docker tag ghcr.io/pytorch-labs/tritonbench:rocm-latest ghcr.io/pytorch-labs/tritonbench:rocm-${DOCKER_TAG}
+      - name: Push docker to remote
+        if: github.event_name != 'pull_request'
+        run: |
+          # Extract pytorch version from the docker
"${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"') + export DOCKER_TAG=$(awk '{match($0, /dev[0-9]+/, arr); print arr[0]}' <<< "${PYTORCH_VERSION}") + docker push ghcr.io/pytorch-labs/tritonbench:rocm-${DOCKER_TAG} + docker push ghcr.io/pytorch-labs/tritonbench:rocm-latest +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index c873f115..25b6bc5f 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -6,7 +6,7 @@ on: pull_request: paths: - .github/workflows/docker.yaml - - docker/*.dockerfile + - docker/tritonbench-nightly.dockerfile workflow_dispatch: inputs: nightly_date: diff --git a/README.md b/README.md index d2e5e7d1..160ec7dd 100644 --- a/README.md +++ b/README.md @@ -36,12 +36,14 @@ $ python run.py --op gemm We depend on the following projects as a source of customized Triton or CUTLASS kernels: -* (Required) [FBGEMM](https://github.com/pytorch/FBGEMM) -* (Required) [kernels](https://github.com/triton-lang/kernels) -* (Required) [generative-recommenders](https://github.com/facebookresearch/generative-recommenders) -* (Optional) [ThunderKittens](https://github.com/HazyResearch/ThunderKittens) -* (Optional) [cutlass-kernels](https://github.com/ColfaxResearch/cutlass-kernels) -* (Optional) [flash-attention](https://github.com/Dao-AILab/flash-attention) +* (CUDA, HIP) [kernels](https://github.com/triton-lang/kernels) +* (CUDA, HIP) [generative-recommenders](https://github.com/facebookresearch/generative-recommenders) +* (CUDA, HIP) [Liger-Kernel](https://github.com/linkedin/Liger-Kernel) +* (CUDA) [xformers](https://github.com/facebookresearch/xformers) +* (CUDA) [flash-attention](https://github.com/Dao-AILab/flash-attention) +* (CUDA) [FBGEMM](https://github.com/pytorch/FBGEMM) +* (CUDA) [ThunderKittens](https://github.com/HazyResearch/ThunderKittens) +* (CUDA) [cutlass-kernels](https://github.com/ColfaxResearch/cutlass-kernels) ## License diff --git a/docker/tritonbench-nightly.dockerfile b/docker/tritonbench-nightly.dockerfile index a2d09fde..71d2c5d9 100644 --- a/docker/tritonbench-nightly.dockerfile +++ b/docker/tritonbench-nightly.dockerfile @@ -9,6 +9,9 @@ ENV SETUP_SCRIPT=/workspace/setup_instance.sh ARG TRITONBENCH_BRANCH=${TRITONBENCH_BRANCH:-main} ARG FORCE_DATE=${FORCE_DATE} +# Install deps +RUN sudo apt install -y patch + # Checkout TritonBench and submodules RUN git clone --recurse-submodules -b "${TRITONBENCH_BRANCH}" --single-branch \ https://github.com/pytorch-labs/tritonbench /workspace/tritonbench @@ -22,13 +25,13 @@ RUN cd /workspace/tritonbench && \ RUN cd /workspace/tritonbench && \ . ${SETUP_SCRIPT} && \ - sudo python tools/cuda_utils.py --setup-cuda-softlink + sudo python -m tools.cuda_utils --setup-cuda-softlink # Install PyTorch nightly and verify the date is correct RUN cd /workspace/tritonbench && \ . 
     . ${SETUP_SCRIPT} && \
-    python tools/cuda_utils.py --install-torch-deps && \
-    python tools/cuda_utils.py --install-torch-nightly
+    python -m tools.cuda_utils --install-torch-deps && \
+    python -m tools.cuda_utils --install-torch-nightly
 
 # Check the installed version of nightly if needed
 RUN cd /workspace/tritonbench && \
@@ -37,9 +40,9 @@ RUN cd /workspace/tritonbench && \
         echo "torch version check skipped"; \
     elif [ -z "${FORCE_DATE}" ]; then \
         FORCE_DATE=$(date '+%Y%m%d') \
-        python tools/cuda_utils.py --check-torch-nightly-version --force-date "${FORCE_DATE}"; \
+        python -m tools.cuda_utils --check-torch-nightly-version --force-date "${FORCE_DATE}"; \
     else \
-        python tools/cuda_utils.py --check-torch-nightly-version --force-date "${FORCE_DATE}"; \
+        python -m tools.cuda_utils --check-torch-nightly-version --force-date "${FORCE_DATE}"; \
     fi
 
 # Tritonbench library build and test require libcuda.so.1
diff --git a/docker/tritonbench-rocm-nightly.dockerfile b/docker/tritonbench-rocm-nightly.dockerfile
new file mode 100644
index 00000000..5fb5f3e6
--- /dev/null
+++ b/docker/tritonbench-rocm-nightly.dockerfile
@@ -0,0 +1,44 @@
+# Build ROCM base docker file
+# We are not building AMD CI in the short term, but this could be useful
+# for sharing benchmark results with AMD.
+ARG BASE_IMAGE=rocm/pytorch:latest
+
+FROM ${BASE_IMAGE}
+
+ENV CONDA_ENV=pytorch
+ENV CONDA_ENV_TRITON_MAIN=triton-main
+ENV SETUP_SCRIPT=/workspace/setup_instance.sh
+ARG TRITONBENCH_BRANCH=${TRITONBENCH_BRANCH:-main}
+ARG FORCE_DATE=${FORCE_DATE}
+
+RUN mkdir -p /workspace; touch "${SETUP_SCRIPT}"
+
+RUN echo "\
+. /opt/conda/etc/profile.d/conda.sh\n\
+conda activate base\n\
+export CONDA_HOME=/opt/conda\n" > "${SETUP_SCRIPT}"
+
+RUN echo ". /workspace/setup_instance.sh\n" >> ${HOME}/.bashrc
+
+# Checkout TritonBench and submodules
+RUN git clone --recurse-submodules -b "${TRITONBENCH_BRANCH}" --single-branch \
+    https://github.com/pytorch-labs/tritonbench /workspace/tritonbench
+
+# Setup conda env
+RUN cd /workspace/tritonbench && \
+    . ${SETUP_SCRIPT} && \
+    python tools/python_utils.py --create-conda-env ${CONDA_ENV} && \
+    echo "if [ -z \${CONDA_ENV} ]; then export CONDA_ENV=${CONDA_ENV}; fi" >> "${SETUP_SCRIPT}" && \
+    echo "conda activate \${CONDA_ENV}" >> "${SETUP_SCRIPT}"
+
+
+# Install PyTorch nightly and verify the date is correct
+RUN cd /workspace/tritonbench && \
+    . ${SETUP_SCRIPT} && \
+    python -m tools.rocm_utils --install-torch-deps && \
+    python -m tools.rocm_utils --install-torch-nightly
+
+
+# Install Tritonbench
+RUN cd /workspace/tritonbench && \
+    bash .ci/tritonbench/install.sh
diff --git a/install.py b/install.py
index 1be19df0..006e3a49 100644
--- a/install.py
+++ b/install.py
@@ -8,6 +8,7 @@
 from tools.cuda_utils import CUDA_VERSION_MAP, DEFAULT_CUDA_VERSION
 from tools.git_utils import checkout_submodules
 from tools.python_utils import pip_install_requirements
+from tools.torch_utils import is_hip
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -77,6 +78,13 @@ def install_liger():
     subprocess.check_call(cmd)
 
 
+def setup_hip(args: argparse.Namespace):
+    # We have to disable all third-party libraries that do not support HIP/ROCm
+    args.all = False
+    args.liger = True
+    args.hstu = True
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(allow_abbrev=False)
     parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
@@ -105,6 +113,9 @@ def install_liger():
     parser.add_argument("--test", action="store_true", help="Run tests")
     args = parser.parse_args()
 
+    if args.all and is_hip():
+        setup_hip(args)
+
     # install framework dependencies
     pip_install_requirements("requirements.txt")
     # checkout submodules
diff --git a/tools/cuda_utils.py b/tools/cuda_utils.py
index ba46be49..59d514d1 100644
--- a/tools/cuda_utils.py
+++ b/tools/cuda_utils.py
@@ -1,14 +1,11 @@
-import importlib
+import argparse
 import os
 import re
 import subprocess
 from pathlib import Path
 
-from typing import Optional
-
 # defines the default CUDA version to compile against
 DEFAULT_CUDA_VERSION = "12.4"
-REPO_ROOT = Path(__file__).parent.parent
 
 CUDA_VERSION_MAP = {
     "12.4": {
@@ -17,10 +14,6 @@
         "jax": "jax[cuda12]",
     },
 }
-PIN_CMAKE_VERSION = "3.22.*"
-
-TORCH_NIGHTLY_PACKAGES = ["torch"]
-BUILD_REQUIREMENTS_FILE = REPO_ROOT.joinpath("utils", "build_requirements.txt")
 
 
 def _nvcc_output_match(nvcc_output, target_cuda_version):
@@ -94,6 +87,8 @@ def setup_cuda_softlink(cuda_version: str):
 
 
 def install_pytorch_nightly(cuda_version: str, env, dryrun=False):
+    from .torch_utils import TORCH_NIGHTLY_PACKAGES
+
     uninstall_torch_cmd = ["pip", "uninstall", "-y"]
     uninstall_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES)
     if dryrun:
@@ -137,68 +132,7 @@ def install_torch_deps(cuda_version: str):
     subprocess.check_call(cmd)
 
 
-def install_torch_build_deps(cuda_version: str):
-    install_torch_deps(cuda_version=cuda_version)
-    # Pin cmake version to stable
-    # See: https://github.com/pytorch/builder/pull/1269
-    torch_build_deps = [
-        "cffi",
-        "sympy",
-        "typing_extensions",
-        "future",
-        "six",
-        "dataclasses",
-        "tabulate",
-        "tqdm",
-        "mkl",
-        "mkl-include",
-        f"cmake={PIN_CMAKE_VERSION}",
-    ]
-    cmd = ["conda", "install", "-y"] + torch_build_deps
-    subprocess.check_call(cmd)
-    build_deps = ["ffmpeg"]
-    cmd = ["conda", "install", "-y"] + build_deps
-    subprocess.check_call(cmd)
-    # pip build deps
-    cmd = ["pip", "install", "-r"] + [str(BUILD_REQUIREMENTS_FILE.resolve())]
-    subprocess.check_call(cmd)
-    # conda forge deps
-    # ubuntu 22.04 comes with libstdcxx6 12.3.0
-    # we need to install the same library version in conda to maintain ABI compatibility
-    conda_deps = ["libstdcxx-ng=12.3.0"]
-    cmd = ["conda", "install", "-y", "-c", "conda-forge"] + conda_deps
-    subprocess.check_call(cmd)
-
-
-def get_torch_nightly_version(pkg_name: str):
-    pkg = importlib.import_module(pkg_name)
-    version = pkg.__version__
-    regex = ".*dev([0-9]+).*"
-    date_str = re.match(regex, version).groups()[0]
-    pkg_ver = {"version": version, "date": date_str}
-    return (pkg_name, pkg_ver)
-
-
-def check_torch_nightly_version(force_date: Optional[str] = None):
-    pkg_versions = dict(map(get_torch_nightly_version, TORCH_NIGHTLY_PACKAGES))
-    pkg_dates = [x[1]["date"] for x in pkg_versions.items()]
-    if not len(set(pkg_dates)) == 1:
-        raise RuntimeError(
-            f"Found more than 1 dates in the torch nightly packages: {pkg_versions}."
-        )
-    if force_date and not pkg_dates[0] == force_date:
-        raise RuntimeError(
-            f"Force date value {force_date}, but found torch packages {pkg_versions}."
-        )
-    force_date_str = f"User force date {force_date}" if force_date else ""
-    print(
-        f"Installed consistent torch nightly packages: {pkg_versions}. {force_date_str}"
-    )
-
-
 if __name__ == "__main__":
-    import argparse
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--cudaver",
@@ -240,9 +174,14 @@ def check_torch_nightly_version(force_date: Optional[str] = None):
     if args.install_torch_deps:
         install_torch_deps(cuda_version=args.cudaver)
     if args.install_torch_build_deps:
-        install_torch_build_deps(cuda_version=args.cudaver)
+        from .torch_utils import install_torch_build_deps
+
+        install_torch_deps(cuda_version=args.cudaver)
+        install_torch_build_deps()
     if args.install_torch_nightly:
         install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ)
     if args.check_torch_nightly_version:
+        from .torch_utils import check_torch_nightly_version
+
         assert not args.install_torch_nightly, "Error: Can't run install torch nightly and check version in the same command."
         check_torch_nightly_version(args.force_date)
diff --git a/tools/rocm_utils.py b/tools/rocm_utils.py
new file mode 100644
index 00000000..0394af1a
--- /dev/null
+++ b/tools/rocm_utils.py
@@ -0,0 +1,82 @@
+import argparse
+import os
+import subprocess
+
+# defines the default ROCM version to compile against
+DEFAULT_ROCM_VERSION = "6.2"
+ROCM_VERSION_MAP = {
+    "6.2": {
+        "pytorch_url": "rocm6.2",
+    },
+}
+
+
+def install_torch_deps():
+    # install other dependencies
+    torch_deps = [
+        "requests",
+        "ninja",
+        "pyyaml",
+        "setuptools",
+        "gitpython",
+        "beautifulsoup4",
+        "regex",
+    ]
+    cmd = ["conda", "install", "-y"] + torch_deps
+    subprocess.check_call(cmd)
+    # conda forge deps
+    # ubuntu 22.04 comes with libstdcxx6 12.3.0
+    # we need to install the same library version in conda
+    conda_deps = ["libstdcxx-ng=12.3.0"]
+    cmd = ["conda", "install", "-y", "-c", "conda-forge"] + conda_deps
+    subprocess.check_call(cmd)
+
+
+def install_pytorch_nightly(rocm_version: str, env, dryrun=False):
+    from .torch_utils import TORCH_NIGHTLY_PACKAGES
+
+    uninstall_torch_cmd = ["pip", "uninstall", "-y"]
+    uninstall_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES)
+    if dryrun:
+        print(f"Uninstall pytorch: {uninstall_torch_cmd}")
+    else:
+        # uninstall multiple times to make sure the env is clean
+        for _loop in range(3):
+            subprocess.check_call(uninstall_torch_cmd)
+    pytorch_nightly_url = f"https://download.pytorch.org/whl/nightly/{ROCM_VERSION_MAP[rocm_version]['pytorch_url']}"
+    install_torch_cmd = ["pip", "install", "--pre", "--no-cache-dir"]
+    install_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES)
+    install_torch_cmd.extend(["-i", pytorch_nightly_url])
+    if dryrun:
+        print(f"Install pytorch nightly: {install_torch_cmd}")
+    else:
+        subprocess.check_call(install_torch_cmd, env=env)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--rocmver",
+        default=DEFAULT_ROCM_VERSION,
+        help="Specify rocm version.",
+    )
+    parser.add_argument(
+        "--install-torch-deps",
+        action="store_true",
+        help="Install pytorch runtime dependencies",
+    )
+    parser.add_argument(
+        "--install-torch-nightly",
+        action="store_true",
+        help="Install pytorch nightly",
+    )
+    parser.add_argument(
+        "--dryrun",
+        action="store_true",
+        help="Dryrun the commands",
+    )
+    args = parser.parse_args()
+    if args.install_torch_deps:
+        install_torch_deps()
+    if args.install_torch_nightly:
+        install_pytorch_nightly(args.rocmver, env=os.environ, dryrun=args.dryrun)
diff --git a/tools/torch_utils.py b/tools/torch_utils.py
new file mode 100644
index 00000000..d6e96ca1
--- /dev/null
+++ b/tools/torch_utils.py
@@ -0,0 +1,81 @@
+"""
+CUDA/ROCm independent PyTorch installation helpers.
+"""
+
+import importlib
+import re
+import subprocess
+from pathlib import Path
+
+from typing import Optional
+
+REPO_ROOT = Path(__file__).parent.parent
+
+TORCH_NIGHTLY_PACKAGES = ["torch"]
+PIN_CMAKE_VERSION = "3.22.*"
+BUILD_REQUIREMENTS_FILE = REPO_ROOT.joinpath("utils", "build_requirements.txt")
+
+
+def is_hip() -> bool:
+    import torch
+
+    version = torch.__version__
+    return "rocm" in version
+
+
+def install_torch_build_deps():
+    # Pin cmake version to stable
+    # See: https://github.com/pytorch/builder/pull/1269
+    torch_build_deps = [
+        "cffi",
+        "sympy",
+        "typing_extensions",
+        "future",
+        "six",
+        "dataclasses",
+        "tabulate",
+        "tqdm",
+        "mkl",
+        "mkl-include",
+        f"cmake={PIN_CMAKE_VERSION}",
+    ]
+    cmd = ["conda", "install", "-y"] + torch_build_deps
+    subprocess.check_call(cmd)
+    build_deps = ["ffmpeg"]
+    cmd = ["conda", "install", "-y"] + build_deps
+    subprocess.check_call(cmd)
+    # pip build deps
+    cmd = ["pip", "install", "-r"] + [str(BUILD_REQUIREMENTS_FILE.resolve())]
+    subprocess.check_call(cmd)
+    # conda forge deps
+    # ubuntu 22.04 comes with libstdcxx6 12.3.0
+    # we need to install the same library version in conda to maintain ABI compatibility
+    conda_deps = ["libstdcxx-ng=12.3.0"]
+    cmd = ["conda", "install", "-y", "-c", "conda-forge"] + conda_deps
+    subprocess.check_call(cmd)
+
+
+def get_torch_nightly_version(pkg_name: str):
+    pkg = importlib.import_module(pkg_name)
+    version = pkg.__version__
+    regex = ".*dev([0-9]+).*"
+    date_str = re.match(regex, version).groups()[0]
+    pkg_ver = {"version": version, "date": date_str}
+    return (pkg_name, pkg_ver)
+
+
+def check_torch_nightly_version(force_date: Optional[str] = None):
+    pkg_versions = dict(map(get_torch_nightly_version, TORCH_NIGHTLY_PACKAGES))
+    pkg_dates = [x[1]["date"] for x in pkg_versions.items()]
+    if not len(set(pkg_dates)) == 1:
+        raise RuntimeError(
+            f"Found more than one date in the torch nightly packages: {pkg_versions}."
+        )
+    if force_date and not pkg_dates[0] == force_date:
+        raise RuntimeError(
+            f"Force date value {force_date}, but found torch packages {pkg_versions}."
+        )
+    force_date_str = f"User force date {force_date}" if force_date else ""
+    print(
+        f"Installed consistent torch nightly packages: {pkg_versions}. {force_date_str}"
+    )
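
A quick way to exercise the new ROCm tooling outside the docker build (a rough sketch, assuming the commands run from the repository root inside an activated conda environment; every flag used here is defined in tools/rocm_utils.py and install.py in the patch above):

    # preview the pip commands with --dryrun, then install the ROCm nightly wheels
    python -m tools.rocm_utils --install-torch-deps
    python -m tools.rocm_utils --install-torch-nightly --rocmver 6.2 --dryrun
    python -m tools.rocm_utils --install-torch-nightly --rocmver 6.2
    # on a ROCm build of torch, is_hip() returns True and setup_hip() narrows --all
    # down to the HIP-capable backends (liger and hstu)
    python install.py --all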