feat: working Slurm benchmarks
Hugoch committed Sep 20, 2024
1 parent 36fff06 commit aa6a38f
Showing 6 changed files with 69 additions and 56 deletions.
34 changes: 24 additions & 10 deletions extra/slurm/benchmark.py
@@ -1,31 +1,45 @@
 import os
 import subprocess
 
-CPUS_PER_GPU = 11
+CPUS_PER_GPU = 20
 MEM_PER_CPU_GB = 11
 
 
 def main():
     models = [
-        ('meta-llama/Meta-Llama-3.1-8B-Instruct', 1),
+        # ('meta-llama/Meta-Llama-3.1-8B-Instruct', 1),
         ('meta-llama/Meta-Llama-3.1-70B-Instruct', 4),
-        ('mistralai/Mixtral-8x7B-Instruct-v0.1', 2),
+        # ('mistralai/Mixtral-8x7B-Instruct-v0.1', 2),
     ]
     engines = ['tgi', 'vllm']
     for model in models:
         print(f"Submitting job for {model[0]}")
         gpus = model[1]
         cpus_per_task = gpus * CPUS_PER_GPU
         mem_per_cpu = gpus * MEM_PER_CPU_GB
         for engine in engines:
             job_name = f'bench_{engine}_{model[0].replace("/", "_")}'
-            args = ['sbatch', '--cpus-per-task', str(cpus_per_task), '--mem-per-cpu', str(mem_per_cpu) + 'G', '--gpus',
-                    str(gpus), '--nodes', '1',
-                    '--job-name', job_name, f'{engine}.slurm']
-            token = os.environ.get('HF_TOKEN', '')
-            path = os.environ.get('PATH', '')
+            args = ['sbatch',
+                    '--job-name', job_name,
+                    '--output', f'/fsx/%u/logs/%x-%j.log',
+                    '--time', '1:50:00',
+                    '--qos', 'normal',
+                    '--partition', 'hopper-prod',
+                    '--gpus', str(gpus),
+                    '--ntasks', '1',
+                    '--cpus-per-task', str(cpus_per_task),
+                    '--mem-per-cpu', str(MEM_PER_CPU_GB) + 'G',
+                    '--nodes', '1',
+                    ':',
+                    '--gpus', '1',
+                    '--ntasks', '1',
+                    '--cpus-per-task', str(CPUS_PER_GPU),
+                    '--mem-per-cpu', str(MEM_PER_CPU_GB) + 'G',
+                    '--nodes', '1',
+                    f'{engine}.slurm']
+            env = os.environ.copy()
+            env['MODEL'] = model[0]
             process = subprocess.run(args, capture_output=True,
-                                     env={'MODEL': model[0], 'HF_TOKEN': token, 'PATH': path})
+                                     env=env)
             print(process.stdout.decode())
             print(process.stderr.decode())
             if process.returncode != 0:
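Note: the ':' element in the new argument list is Slurm's separator for heterogeneous job components, so each submission requests two allocations: one sized to the model (the inference server) and one fixed single-GPU allocation (the benchmark client). For reference, a sketch of the command line this builds for the 70B/TGI case, with sizes derived from the constants above (illustrative only, not part of the commit):

# Hypothetical equivalent of the assembled sbatch call; ':' splits the
# het-job into component 0 (server) and component 1 (benchmark client).
# MODEL is exported to the job via the submission environment.
MODEL=meta-llama/Meta-Llama-3.1-70B-Instruct \
sbatch \
    --job-name bench_tgi_meta-llama_Meta-Llama-3.1-70B-Instruct \
    --output '/fsx/%u/logs/%x-%j.log' \
    --time 1:50:00 --qos normal --partition hopper-prod \
    --gpus 4 --ntasks 1 --cpus-per-task 80 --mem-per-cpu 11G --nodes 1 \
    : \
    --gpus 1 --ntasks 1 --cpus-per-task 20 --mem-per-cpu 11G --nodes 1 \
    tgi.slurm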
41 changes: 19 additions & 22 deletions extra/slurm/tgi.slurm
@@ -1,31 +1,31 @@
 #!/usr/bin/env bash
-#SBATCH --job-name tgi-benchmark
-#SBATCH --output /fsx/%u/logs/%x-%j.log
-#SBATCH --gpus 1
-#SBATCH --ntasks 2
-#SBATCH --cpus-per-task 11
-#SBATCH --mem-per-cpu 20G
-#SBATCH --time 1:50:00
-#SBATCH --partition hopper-prod
 #SBATCH --qos normal
+#SBATCH --partition hopper-prod
+#SBATCH --gpus 1 --ntasks 1 --cpus-per-task 11 --mem-per-cpu 20G --nodes=1
+#SBATCH hetjob
+#SBATCH --gpus 1 --ntasks 1 --cpus-per-task 11 --mem-per-cpu 20G --nodes=1
 
 
 if [ -z "$MODEL" ]; then
     echo "MODEL environment variable is not set"
     exit 1
 fi
 
-#model="meta-llama/Meta-Llama-3.1-8B-Instruct"
 echo "Starting TGI benchmark for $MODEL"
 export RUST_BACKTRACE=full
 export RUST_LOG=text_generation_inference_benchmark=info
 export PORT=8090
 
+echo "Model will run on ${SLURM_JOB_NODELIST_HET_GROUP_0}:${PORT}"
+echo "Benchmark will run on ${SLURM_JOB_NODELIST_HET_GROUP_1}"
+
 # start TGI
-srun -u \
+srun --het-group=0 \
+     -u \
      -n 1 \
-     --mem-per-cpu 20G \
-     --cpus-per-task 11 \
-     --gpus=1 \
-     --exact \
      --container-image='ghcr.io#huggingface/text-generation-inference' \
      --container-env=PORT \
      --container-mounts="/scratch:/data" \
@@ -34,7 +34,7 @@ srun -u \
     /usr/local/bin/text-generation-launcher \
         --model-id $MODEL \
         --max-concurrent-requests 512 \
-        --cuda-graphs="2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46,48,50,52,54,56,58,60,62,64,66,68,70,72,74,76,78,80,82,84,86,88,90,92"&
+        --cuda-graphs="1,8,16,24,32,40,48,56,64,72,80,88,96,104,112,120,128"&
 
 # wait until /health is available, die after 5 minutes
 timeout 300 bash -c "while [[ \"\$(curl -s -o /dev/null -w '%{http_code}' http://localhost:${PORT}/health)\" != \"200\" ]]; do sleep 1 && echo \"Waiting for TGI to start...\"; done" || exit 1
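Aside: GNU timeout exits with status 124 when the deadline expires, which is the convention the `$exit_code != 124` guard in the next hunk relies on (the variable itself is presumably captured from a similar timeout invocation in the collapsed lines). A minimal sketch of the convention:

# timeout(1) propagates the command's exit status, or returns 124 on expiry
timeout 2 bash -c 'sleep 5'
echo $?   # prints 124: the deadline was hit before the command finished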
@@ -46,28 +46,25 @@ mkdir -p "${RESULTS_DIR}"
 if [[ $exit_code != 124 ]]; then
     # run benchmark
     echo "Starting benchmark"
-    srun -u \
+    srun --het-group=1 \
+         -u \
          -n 1 \
-         --mem-per-cpu 20G \
-         --cpus-per-task 11 \
-         --gpus 0 \
-         --exact \
-         --container-image='registry.hpc-cluster-hopper.hpc.internal.huggingface.tech/library/text-generation-inference-benchmark:latest' \
+         --container-image="registry.hpc-cluster-hopper.hpc.internal.huggingface.tech#library/text-generation-inference-benchmark:latest" \
          --container-mounts="${RESULTS_DIR}:/opt/text-generation-inference-benchmark/results" \
          --no-container-mount-home \
          text-generation-inference-benchmark \
-            --tokenizer-name $MODEL \
+            --tokenizer-name "$MODEL" \
             --max-vus 800 \
-            --url "http://localhost:${PORT}" \
+            --url "http://${SLURM_JOB_NODELIST_HET_GROUP_0}:${PORT}" \
             --duration 30s \
-            --warmup 20s \
+            --warmup 30s \
             --num-rates 2 \
             --prompt-options "num_tokens=200,max_tokens=220,min_tokens=180,variance=10" \
             --decode-options "num_tokens=200,max_tokens=220,min_tokens=180,variance=10" \
             --no-console
 fi
 
 # stop TGI
-scancel --signal=TERM "$SLURM_JOB_ID.0"
+scancel --signal=TERM "$SLURM_JOB_ID+0"
 
 echo "End of benchmark"
45 changes: 23 additions & 22 deletions extra/slurm/vllm.slurm
@@ -1,13 +1,13 @@
 #!/usr/bin/env bash
-#SBATCH --job-name tgi-benchmark
-#SBATCH --output /fsx/%u/logs/%x-%j.log
-#SBATCH --gpus 1
-#SBATCH --ntasks 2
-#SBATCH --cpus-per-task 11
-#SBATCH --mem-per-cpu 20G
-#SBATCH --time 1:50:00
-#SBATCH --partition hopper-prod
 #SBATCH --qos normal
+#SBATCH --partition hopper-prod
+#SBATCH --gpus 1 --ntasks 1 --cpus-per-task 11 --mem-per-cpu 20G --nodes=1
+#SBATCH hetjob
+#SBATCH --gpus 1 --ntasks 1 --cpus-per-task 11 --mem-per-cpu 20G --nodes=1
 
 
 if [ -z "$MODEL" ]; then
     echo "MODEL environment variable is not set"
@@ -16,21 +16,25 @@ fi

echo "Starting vLLM benchmark for $MODEL"
export RUST_BACKTRACE=full
export RUST_LOG=text_generation_inference_benchmark=info;
export RUST_LOG=text_generation_inference_benchmark=info
export PORT=8090
# start TGI
srun -u \

echo "Model will run on ${SLURM_JOB_NODELIST_HET_GROUP_0}:${PORT}"
echo "Benchmark will run on ${SLURM_JOB_NODELIST_HET_GROUP_1}"

# start vLLM
srun --het-group=0 \
-u \
-n 1 \
--mem-per-cpu 20G \
--cpus-per-task 11 \
--gpus=1 \
--exact \
--container-image='vllm/vllm-openai:latest' \
--container-env=PORT \
--container-mounts="/scratch:/root/.cache/huggingface" \
--container-workdir='/usr/src' \
--no-container-mount-home \
python3 -m vllm.entrypoints.openai.api_server --model "${MODEL}" --port "${PORT}" &
python3 -m vllm.entrypoints.openai.api_server \
--model "${MODEL}" \
--port "${PORT}" \
--tensor-parallel-size "${SLURM_GPUS_ON_NODE}"&

# wait until /health is available, die after 5 minutes
timeout 300 bash -c "while [[ \"\$(curl -s -o /dev/null -w '%{http_code}' http://localhost:${PORT}/health)\" != \"200\" ]]; do sleep 1 && echo \"Waiting for vLLM to start...\"; done" || exit 1
@@ -42,28 +46,25 @@ mkdir -p "${RESULTS_DIR}"
 if [[ $exit_code != 124 ]]; then
     # run benchmark
     echo "Starting benchmark"
-    srun -u \
+    srun --het-group=1 \
+         -u \
          -n 1 \
-         --mem-per-cpu 20G \
-         --cpus-per-task 11 \
-         --gpus 0 \
-         --exact \
-         --container-image='registry.hpc-cluster-hopper.hpc.internal.huggingface.tech/library/text-generation-inference-benchmark:latest' \
+         --container-image="registry.hpc-cluster-hopper.hpc.internal.huggingface.tech#library/text-generation-inference-benchmark:latest" \
          --container-mounts="${RESULTS_DIR}:/opt/text-generation-inference-benchmark/results" \
          --no-container-mount-home \
          text-generation-inference-benchmark \
             --tokenizer-name "$MODEL" \
             --max-vus 800 \
-            --url "http://localhost:${PORT}" \
+            --url "http://${SLURM_JOB_NODELIST_HET_GROUP_0}:${PORT}" \
             --duration 30s \
-            --warmup 20s \
+            --warmup 30s \
             --num-rates 2 \
             --prompt-options "num_tokens=200,max_tokens=220,min_tokens=180,variance=10" \
             --decode-options "num_tokens=200,max_tokens=220,min_tokens=180,variance=10" \
             --no-console
 fi
 
 # stop TGI
-scancel --signal=TERM "$SLURM_JOB_ID.0"
+scancel --signal=TERM "$SLURM_JOB_ID+0"
 
 echo "End of benchmark"
3 changes: 1 addition & 2 deletions plot.py
@@ -1,7 +1,6 @@
 import json
 import os
 
-import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
@@ -87,7 +86,7 @@ def plot_inner(x_title, x_key, results, chart_title):
 if __name__ == '__main__':
     # list json files in results directory
     data_files = {}
-    for filename in os.listdir('results'):
+    for filename in os.listdir('results/llama-70B'):
         if filename.endswith('.json'):
             data_files[filename.split('.')[0]] = f'results/{filename}'
     plot(data_files)
1 change: 1 addition & 0 deletions src/benchmark.rs
@@ -69,6 +69,7 @@ pub struct BenchmarkConfig {
     pub num_rates: u64,
     pub prompt_options: Option<TokenizeOptions>,
     pub decode_options: Option<TokenizeOptions>,
+    pub tokenizer: String,
 }
 
 impl BenchmarkConfig {
1 change: 1 addition & 0 deletions src/lib.rs
@@ -65,6 +65,7 @@ pub async fn run(run_config: RunConfiguration,
         num_rates: run_config.num_rates,
         prompt_options: run_config.prompt_options.clone(),
         decode_options: run_config.decode_options.clone(),
+        tokenizer: run_config.tokenizer_name.clone(),
     };
     config.validate()?;
     let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
