Skip to content

Commit

Permalink
Add a3 ultra support
Browse files Browse the repository at this point in the history
  • Loading branch information
samos123 committed Jan 22, 2025
1 parent 185b1b5 commit 08d5a03
Show file tree
Hide file tree
Showing 6 changed files with 351 additions and 140 deletions.
52 changes: 51 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,62 @@ COPY . .
# GPU container spec. #
################################################################################

FROM base AS gpu
# Disabled: building the GPU stage on top of `base` (as below) causes the error "INTERNAL: No valid engine configs for Matmul".
# FROM base AS gpu
#
# RUN apt-get update && apt-get install -y ibverbs-utils
# # TODO(markblee): Support extras.
# ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# RUN pip install .[core,gpu]
# RUN pip install -U "jax[gpu]==0.4.37" "jax==0.4.37" "jaxlib==0.4.36" \
# -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# COPY . .

# GPU stage: starts from NVIDIA's CUDA 12.6.3 + cuDNN development image on
# Ubuntu 22.04.
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 AS gpu

# Copy from original base
# Toolchain, Python, git, and the utilities used by the launch script.
# `apt-get update` runs in the same layer as `install` so a cached, stale
# package index can never be paired with a newer install list, and the index is
# deleted again in that same layer so it does not end up in the image.
# Recommended packages are intentionally left enabled so the installed set
# stays the same as before.
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y \
        apt-transport-https \
        ca-certificates \
        curl \
        g++ \
        gcc \
        git \
        gnupg \
        ibverbs-utils \
        jq \
        python3 \
        python3-venv \
        screen && \
    rm -rf /var/lib/apt/lists/*

# Expose the interpreter under the plain `python` name.
RUN ln -s /usr/bin/python3 /usr/bin/python

# Install gcloud. https://cloud.google.com/sdk/docs/install
# The signing key is downloaded to a file first (instead of being piped into
# gpg) so that a failed download aborts the build; `-f` makes curl fail on HTTP
# errors.
RUN curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg \
        -o /tmp/cloud.google.key && \
    gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg /tmp/cloud.google.key && \
    rm -f /tmp/cloud.google.key && \
    echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" \
        > /etc/apt/sources.list.d/google-cloud-sdk.list && \
    apt-get update -y && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y google-cloud-cli && \
    rm -rf /var/lib/apt/lists/*

# Setup.
# WORKDIR creates the directory if it is missing, so no separate mkdir is needed.
WORKDIR /root
# Introduce the minimum set of files for install: the package metadata plus an
# empty `axlearn` package, so dependencies can be installed before the full
# source tree is copied in.
COPY README.md README.md
COPY pyproject.toml pyproject.toml
RUN mkdir axlearn && touch axlearn/__init__.py
# Setup venv to suppress pip warnings.
ENV VIRTUAL_ENV=/opt/venv
RUN python -m venv $VIRTUAL_ENV
# Put the venv first on PATH so `python` and `pip` resolve to it from here on.
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# Install dependencies.
# Upgrade pip before using it, and skip pip's download cache so it is not
# stored in the image layers.
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir flit
# End copy original base


# Extra system packages for the GPU image.
# NOTE(review): presumably needed for tcmalloc / libSegFault-style preloads at
# launch time; confirm against the launch scripts.
# `apt-get` is used rather than `apt` because `apt` does not offer a stable
# command-line interface for scripts; the package index is removed in the same
# layer.
RUN apt-get update -y && \
    apt-get install -y glibc-tools google-perftools && \
    rm -rf /var/lib/apt/lists/*

# TODO(markblee): Support extras.
ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# The requirement is quoted so the shell cannot treat `[core,gpu]` as a glob.
RUN pip install --no-cache-dir ".[core,gpu]"
# Pin JAX after the project install so these versions win. This runs before the
# source copy so that editing source files does not invalidate the layer and
# trigger a reinstall.
RUN pip install --no-cache-dir -U "jax[gpu]==0.4.38" "jax==0.4.38" "jaxlib==0.4.38"
# Bring in the full source tree last (a single copy is sufficient).
COPY . .


################################################################################
# Final target spec. #
Expand Down
Loading

0 comments on commit 08d5a03

Please sign in to comment.