diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 2c620da7..2c4b2adc 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -3,6 +3,9 @@ on: schedule: # Push the nightly docker daily at 3 PM UTC - cron: '0 15 * * *' + pull_request: + paths: + - docker/*.dockerfile workflow_dispatch: inputs: nightly_date: @@ -18,7 +21,6 @@ jobs: build-push-docker: if: ${{ github.repository_owner == 'pytorch-labs' }} runs-on: ubuntu-latest - environment: docker-s3-upload steps: - name: Checkout uses: actions/checkout@v3 @@ -56,4 +58,4 @@ jobs: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true \ No newline at end of file + cancel-in-progress: true diff --git a/docker/tritonbench-nightly.dockerfile b/docker/tritonbench-nightly.dockerfile index 13beeffe..de7abc39 100644 --- a/docker/tritonbench-nightly.dockerfile +++ b/docker/tritonbench-nightly.dockerfile @@ -5,7 +5,7 @@ FROM ${BASE_IMAGE} ENV CONDA_ENV=tritonbench ENV SETUP_SCRIPT=/workspace/setup_instance.sh -ARG TRITONBENCH_BRANCH=${TORCHBENCH_BRANCH:-main} +ARG TRITONBENCH_BRANCH=${TRITONBENCH_BRANCH:-main} ARG FORCE_DATE=${FORCE_DATE} # Checkout TritonBench and submodules @@ -41,12 +41,17 @@ RUN cd /workspace/tritonbench && \ python utils/cuda_utils.py --check-torch-nightly-version --force-date "${FORCE_DATE}"; \ fi +# Tritonbench library build and test require libcuda.so.1 +# which is from NVIDIA driver +RUN sudo apt update && sudo apt-get install -y libnvidia-compute-550 patchelf + # Install Tritonbench RUN cd /workspace/tritonbench && \ bash .ci/tritonbench/install.sh -# Test Tritonbench (libcuda.so.1 is required for fbgemm test, so install libnvidia-compute-550 as a hack) -RUN sudo apt update && sudo apt-get install -y libnvidia-compute-550 && \ - cd /workspace/tritonbench && \ - bash .ci/tritonbench/test-install.sh && \ - sudo apt-get purge -y libnvidia-compute-550 +# Test Tritonbench +RUN cd /workspace/tritonbench && \ + bash .ci/tritonbench/test-install.sh + +# Remove NVIDIA driver library - they are supposed to be mapped at runtime +RUN sudo apt-get purge -y libnvidia-compute-550 diff --git a/install.py b/install.py index 79d43fc1..9ba0b341 100644 --- a/install.py +++ b/install.py @@ -56,11 +56,15 @@ def install_cutlass(): install_colfax_cutlass() -def install_fa(): +def install_fa2(): FA2_PATH = REPO_PATH.joinpath("submodules", "flash-attention") - FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper") cmd = [sys.executable, "setup.py", "install"] subprocess.check_call(cmd, cwd=str(FA2_PATH.resolve())) + + +def install_fa3(): + FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper") + cmd = [sys.executable, "setup.py", "install"] subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve())) @@ -83,7 +87,10 @@ def install_tk(): "--cutlass", action="store_true", help="Install optional CUTLASS kernels" ) parser.add_argument( - "--fa", action="store_true", help="Install optional flash_attention kernels" + "--fa2", action="store_true", help="Install optional flash_attention 2 kernels" + ) + parser.add_argument( + "--fa3", action="store_true", help="Install optional flash_attention 3 kernels" ) parser.add_argument("--jax", action="store_true", help="Install jax nightly") parser.add_argument("--tk", action="store_true", help="Install ThunderKittens") @@ -99,9 +106,12 @@ def install_tk(): if args.fbgemm or args.all: logger.info("[tritonbench] installing FBGEMM...") install_fbgemm() - if args.fa or args.all: - logger.info("[tritonbench] installing fa2 and fa3...") - install_fa() + if args.fa2 or args.all: + logger.info("[tritonbench] installing fa2...") + install_fa2() + if args.fa3 or args.all: + logger.info("[tritonbench] installing fa3...") + install_fa3() if args.cutlass or args.all: logger.info("[tritonbench] installing cutlass-kernels...") install_cutlass() diff --git a/utils/build_requirements.txt b/utils/build_requirements.txt new file mode 100644 index 00000000..36490cca --- /dev/null +++ b/utils/build_requirements.txt @@ -0,0 +1,6 @@ +# We need to pin numpy version to the same as the torch testing environment +# which still supports python 3.8 +numpy==1.21.2; python_version < '3.11' +numpy==1.26.0; python_version >= '3.11' +psutil +pyyaml \ No newline at end of file diff --git a/utils/cuda_utils.py b/utils/cuda_utils.py index d41f71f8..02bfcdb4 100644 --- a/utils/cuda_utils.py +++ b/utils/cuda_utils.py @@ -1,4 +1,3 @@ -import argparse import importlib import os import re @@ -20,7 +19,7 @@ } PIN_CMAKE_VERSION = "3.22.*" -TORCHBENCH_TORCH_NIGHTLY_PACKAGES = ["torch", "torchvision", "torchaudio"] +TORCH_NIGHTLY_PACKAGES = ["torch"] BUILD_REQUIREMENTS_FILE = REPO_ROOT.joinpath("utils", "build_requirements.txt") @@ -96,7 +95,7 @@ def setup_cuda_softlink(cuda_version: str): def install_pytorch_nightly(cuda_version: str, env, dryrun=False): uninstall_torch_cmd = ["pip", "uninstall", "-y"] - uninstall_torch_cmd.extend(TORCHBENCH_TORCH_NIGHTLY_PACKAGES) + uninstall_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES) if dryrun: print(f"Uninstall pytorch: {uninstall_torch_cmd}") else: @@ -105,7 +104,7 @@ def install_pytorch_nightly(cuda_version: str, env, dryrun=False): subprocess.check_call(uninstall_torch_cmd) pytorch_nightly_url = f"https://download.pytorch.org/whl/nightly/{CUDA_VERSION_MAP[cuda_version]['pytorch_url']}" install_torch_cmd = ["pip", "install", "--pre", "--no-cache-dir"] - install_torch_cmd.extend(TORCHBENCH_TORCH_NIGHTLY_PACKAGES) + install_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES) install_torch_cmd.extend(["-i", pytorch_nightly_url]) if dryrun: print(f"Install pytorch nightly: {install_torch_cmd}") @@ -159,18 +158,11 @@ def install_torch_build_deps(cuda_version: str): subprocess.check_call(cmd) # conda forge deps # ubuntu 22.04 comes with libstdcxx6 12.3.0 - # we need to install the same library version in conda + # we need to install the same library version in conda to maintain ABI compatibility conda_deps = ["libstdcxx-ng=12.3.0"] cmd = ["conda", "install", "-y", "-c", "conda-forge"] + conda_deps subprocess.check_call(cmd) - -def install_torchbench_deps(): - # tritonbench flash_attn depends on packaging to build - cmd = ["pip", "install", "unittest-xml-reporting", "boto3", "packaging"] - subprocess.check_call(cmd) - - def get_torch_nightly_version(pkg_name: str): pkg = importlib.import_module(pkg_name) version = pkg.__version__ @@ -182,7 +174,7 @@ def get_torch_nightly_version(pkg_name: str): def check_torch_nightly_version(force_date: Optional[str] = None): pkg_versions = dict( - map(get_torch_nightly_version, TORCHBENCH_TORCH_NIGHTLY_PACKAGES) + map(get_torch_nightly_version, TORCH_NIGHTLY_PACKAGES) ) pkg_dates = list(map(lambda x: x[1]["date"], pkg_versions.items())) if not len(set(pkg_dates)) == 1: @@ -198,3 +190,62 @@ def check_torch_nightly_version(force_date: Optional[str] = None): f"Installed consistent torch nightly packages: {pkg_versions}. {force_date_str}" ) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + "--cudaver", + default=DEFAULT_CUDA_VERSION, + help="Specify the default CUDA version", + ) + parser.add_argument( + "--setup-cuda-softlink", + action="store_true", + help="Setup the softlink to /usr/local/cuda", + ) + parser.add_argument( + "--install-torch-deps", + action="store_true", + help="Install pytorch runtime dependencies", + ) + parser.add_argument( + "--install-torch-build-deps", + action="store_true", + help="Install pytorch build dependencies", + ) + parser.add_argument( + "--install-torch-nightly", action="store_true", help="Install pytorch nightlies" + ) + parser.add_argument( + "--install-torchbench-deps", + action="store_true", + help="Install torchbench conda dependencies", + ) + parser.add_argument( + "--check-torch-nightly-version", + action="store_true", + help="Validate pytorch nightly package consistency", + ) + parser.add_argument( + "--force-date", + type=str, + default=None, + help="Force Pytorch nightly release date version. Date string format: YYmmdd", + ) + args = parser.parse_args() + if args.setup_cuda_softlink: + setup_cuda_softlink(cuda_version=args.cudaver) + if args.install_torch_deps: + install_torch_deps(cuda_version=args.cudaver) + if args.install_torch_build_deps: + install_torch_build_deps(cuda_version=args.cudaver) + if args.install_torch_nightly: + install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ) + if args.install_torchbench_deps: + install_torchbench_deps() + if args.check_torch_nightly_version: + assert ( + not args.install_torch_nightly + ), "Error: Can't run install torch nightly and check version in the same command." + check_torch_nightly_version(args.force_date) diff --git a/utils/python_utils.py b/utils/python_utils.py index 0e1d0841..97af8003 100644 --- a/utils/python_utils.py +++ b/utils/python_utils.py @@ -67,3 +67,24 @@ def pip_install_requirements( except Exception as e: return (False, e) return True, None + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--pyver", + type=str, + default=DEFAULT_PYTHON_VERSION, + help="Specify the Python version.", + ) + parser.add_argument( + "--create-conda-env", + type=str, + default=None, + help="Create conda environment of the default Python version.", + ) + args = parser.parse_args() + if args.create_conda_env: + create_conda_env(args.pyver, args.create_conda_env) +