Skip to content

Commit

Permalink
Fix the docker build (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 authored Oct 11, 2024
1 parent e02e4ca commit 52cc618
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 27 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ on:
schedule:
# Push the nightly docker daily at 3 PM UTC
- cron: '0 15 * * *'
pull_request:
paths:
- docker/*.dockerfile
workflow_dispatch:
inputs:
nightly_date:
Expand All @@ -18,7 +21,6 @@ jobs:
build-push-docker:
if: ${{ github.repository_owner == 'pytorch-labs' }}
runs-on: ubuntu-latest
environment: docker-s3-upload
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down Expand Up @@ -56,4 +58,4 @@ jobs:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
cancel-in-progress: true
17 changes: 11 additions & 6 deletions docker/tritonbench-nightly.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ FROM ${BASE_IMAGE}

ENV CONDA_ENV=tritonbench
ENV SETUP_SCRIPT=/workspace/setup_instance.sh
ARG TRITONBENCH_BRANCH=${TORCHBENCH_BRANCH:-main}
ARG TRITONBENCH_BRANCH=${TRITONBENCH_BRANCH:-main}
ARG FORCE_DATE=${FORCE_DATE}

# Checkout TritonBench and submodules
Expand Down Expand Up @@ -41,12 +41,17 @@ RUN cd /workspace/tritonbench && \
python utils/cuda_utils.py --check-torch-nightly-version --force-date "${FORCE_DATE}"; \
fi

# Building and testing the Tritonbench libraries requires libcuda.so.1,
# which is provided by the NVIDIA driver
RUN sudo apt update && sudo apt-get install -y libnvidia-compute-550 patchelf

# Install Tritonbench
RUN cd /workspace/tritonbench && \
bash .ci/tritonbench/install.sh

# Test Tritonbench (libcuda.so.1 is required for fbgemm test, so install libnvidia-compute-550 as a hack)
RUN sudo apt update && sudo apt-get install -y libnvidia-compute-550 && \
cd /workspace/tritonbench && \
bash .ci/tritonbench/test-install.sh && \
sudo apt-get purge -y libnvidia-compute-550
# Test Tritonbench
RUN cd /workspace/tritonbench && \
bash .ci/tritonbench/test-install.sh

# Remove the NVIDIA driver libraries - they are supposed to be mapped in at runtime
RUN sudo apt-get purge -y libnvidia-compute-550
22 changes: 16 additions & 6 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,15 @@ def install_cutlass():
install_colfax_cutlass()


def install_fa():
def install_fa2():
    """Build and install the flash_attention 2 kernels from the submodule.

    Runs ``setup.py install`` with the current interpreter, using
    ``submodules/flash-attention`` under ``REPO_PATH`` as the working
    directory.

    Raises:
        subprocess.CalledProcessError: If the install command exits with a
            non-zero status.
    """
    FA2_PATH = REPO_PATH.joinpath("submodules", "flash-attention")
    cmd = [sys.executable, "setup.py", "install"]
    # Resolve to an absolute path so the build does not depend on the caller's cwd.
    subprocess.check_call(cmd, cwd=str(FA2_PATH.resolve()))


def install_fa3():
    """Build and install the flash_attention 3 kernels from the submodule.

    Runs ``setup.py install`` with the current interpreter inside the
    ``hopper`` directory of the flash-attention submodule. A failing
    install propagates as ``subprocess.CalledProcessError``.
    """
    hopper_dir = REPO_PATH.joinpath("submodules", "flash-attention", "hopper")
    subprocess.check_call(
        [sys.executable, "setup.py", "install"],
        cwd=str(hopper_dir.resolve()),
    )


Expand All @@ -83,7 +87,10 @@ def install_tk():
"--cutlass", action="store_true", help="Install optional CUTLASS kernels"
)
parser.add_argument(
"--fa", action="store_true", help="Install optional flash_attention kernels"
"--fa2", action="store_true", help="Install optional flash_attention 2 kernels"
)
parser.add_argument(
"--fa3", action="store_true", help="Install optional flash_attention 3 kernels"
)
parser.add_argument("--jax", action="store_true", help="Install jax nightly")
parser.add_argument("--tk", action="store_true", help="Install ThunderKittens")
Expand All @@ -99,9 +106,12 @@ def install_tk():
if args.fbgemm or args.all:
logger.info("[tritonbench] installing FBGEMM...")
install_fbgemm()
if args.fa or args.all:
logger.info("[tritonbench] installing fa2 and fa3...")
install_fa()
if args.fa2 or args.all:
logger.info("[tritonbench] installing fa2...")
install_fa2()
if args.fa3 or args.all:
logger.info("[tritonbench] installing fa3...")
install_fa3()
if args.cutlass or args.all:
logger.info("[tritonbench] installing cutlass-kernels...")
install_cutlass()
Expand Down
6 changes: 6 additions & 0 deletions utils/build_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Pin numpy to the same version used by the torch testing environment,
# which still supports Python 3.8
numpy==1.21.2; python_version < '3.11'
numpy==1.26.0; python_version >= '3.11'
psutil
pyyaml
77 changes: 64 additions & 13 deletions utils/cuda_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import importlib
import os
import re
Expand All @@ -20,7 +19,7 @@
}
PIN_CMAKE_VERSION = "3.22.*"

TORCHBENCH_TORCH_NIGHTLY_PACKAGES = ["torch", "torchvision", "torchaudio"]
TORCH_NIGHTLY_PACKAGES = ["torch"]
BUILD_REQUIREMENTS_FILE = REPO_ROOT.joinpath("utils", "build_requirements.txt")


Expand Down Expand Up @@ -96,7 +95,7 @@ def setup_cuda_softlink(cuda_version: str):

def install_pytorch_nightly(cuda_version: str, env, dryrun=False):
uninstall_torch_cmd = ["pip", "uninstall", "-y"]
uninstall_torch_cmd.extend(TORCHBENCH_TORCH_NIGHTLY_PACKAGES)
uninstall_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES)
if dryrun:
print(f"Uninstall pytorch: {uninstall_torch_cmd}")
else:
Expand All @@ -105,7 +104,7 @@ def install_pytorch_nightly(cuda_version: str, env, dryrun=False):
subprocess.check_call(uninstall_torch_cmd)
pytorch_nightly_url = f"https://download.pytorch.org/whl/nightly/{CUDA_VERSION_MAP[cuda_version]['pytorch_url']}"
install_torch_cmd = ["pip", "install", "--pre", "--no-cache-dir"]
install_torch_cmd.extend(TORCHBENCH_TORCH_NIGHTLY_PACKAGES)
install_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES)
install_torch_cmd.extend(["-i", pytorch_nightly_url])
if dryrun:
print(f"Install pytorch nightly: {install_torch_cmd}")
Expand Down Expand Up @@ -159,18 +158,11 @@ def install_torch_build_deps(cuda_version: str):
subprocess.check_call(cmd)
# conda forge deps
# ubuntu 22.04 comes with libstdcxx6 12.3.0
# we need to install the same library version in conda
# we need to install the same library version in conda to maintain ABI compatibility
conda_deps = ["libstdcxx-ng=12.3.0"]
cmd = ["conda", "install", "-y", "-c", "conda-forge"] + conda_deps
subprocess.check_call(cmd)


def install_torchbench_deps():
    """Install unittest-xml-reporting, boto3 and packaging with pip.

    A failing pip invocation propagates as ``subprocess.CalledProcessError``.
    """
    # tritonbench flash_attn depends on packaging to build
    packages = ["unittest-xml-reporting", "boto3", "packaging"]
    subprocess.check_call(["pip", "install", *packages])


def get_torch_nightly_version(pkg_name: str):
pkg = importlib.import_module(pkg_name)
version = pkg.__version__
Expand All @@ -182,7 +174,7 @@ def get_torch_nightly_version(pkg_name: str):

def check_torch_nightly_version(force_date: Optional[str] = None):
pkg_versions = dict(
map(get_torch_nightly_version, TORCHBENCH_TORCH_NIGHTLY_PACKAGES)
map(get_torch_nightly_version, TORCH_NIGHTLY_PACKAGES)
)
pkg_dates = list(map(lambda x: x[1]["date"], pkg_versions.items()))
if not len(set(pkg_dates)) == 1:
Expand All @@ -198,3 +190,62 @@ def check_torch_nightly_version(force_date: Optional[str] = None):
f"Installed consistent torch nightly packages: {pkg_versions}. {force_date_str}"
)


if __name__ == "__main__":
    # Command-line entry point. Every action flag is independent; the selected
    # steps run in the fixed order below, each using the CUDA version given by
    # --cudaver where applicable.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cudaver",
        default=DEFAULT_CUDA_VERSION,
        help="Specify the default CUDA version",
    )
    parser.add_argument(
        "--setup-cuda-softlink",
        action="store_true",
        help="Setup the softlink to /usr/local/cuda",
    )
    parser.add_argument(
        "--install-torch-deps",
        action="store_true",
        help="Install pytorch runtime dependencies",
    )
    parser.add_argument(
        "--install-torch-build-deps",
        action="store_true",
        help="Install pytorch build dependencies",
    )
    parser.add_argument(
        "--install-torch-nightly", action="store_true", help="Install pytorch nightlies"
    )
    parser.add_argument(
        "--install-torchbench-deps",
        action="store_true",
        help="Install torchbench conda dependencies",
    )
    parser.add_argument(
        "--check-torch-nightly-version",
        action="store_true",
        help="Validate pytorch nightly package consistency",
    )
    parser.add_argument(
        "--force-date",
        type=str,
        default=None,
        help="Force Pytorch nightly release date version. Date string format: YYmmdd",
    )
    args = parser.parse_args()

    # Reject the conflicting flag pair before any step runs. An `assert` is
    # stripped under `python -O`, and checking only at the end would let the
    # install steps execute before the command is rejected.
    if args.install_torch_nightly and args.check_torch_nightly_version:
        parser.error(
            "Can't run install torch nightly and check version in the same command."
        )

    if args.setup_cuda_softlink:
        setup_cuda_softlink(cuda_version=args.cudaver)
    if args.install_torch_deps:
        install_torch_deps(cuda_version=args.cudaver)
    if args.install_torch_build_deps:
        install_torch_build_deps(cuda_version=args.cudaver)
    if args.install_torch_nightly:
        install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ)
    if args.install_torchbench_deps:
        # NOTE(review): confirm install_torchbench_deps() is still defined in
        # this module; if it is not, this flag raises NameError.
        install_torchbench_deps()
    if args.check_torch_nightly_version:
        check_torch_nightly_version(args.force_date)
21 changes: 21 additions & 0 deletions utils/python_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,24 @@ def pip_install_requirements(
except Exception as e:
return (False, e)
return True, None

if __name__ == "__main__":
    # Script entry point: when --create-conda-env NAME is given, create that
    # conda environment using the Python version selected by --pyver.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--pyver",
        default=DEFAULT_PYTHON_VERSION,
        type=str,
        help="Specify the Python version.",
    )
    cli.add_argument(
        "--create-conda-env",
        default=None,
        type=str,
        help="Create conda environment of the default Python version.",
    )
    options = cli.parse_args()

    env_name = options.create_conda_env
    if env_name:
        create_conda_env(options.pyver, env_name)

0 comments on commit 52cc618

Please sign in to comment.