Skip to content

Commit

Permalink
Add scripts to remove redundant checkpoints
Browse files Browse the repository at this point in the history
Assure disk space
  • Loading branch information
research4pan committed Nov 20, 2023
1 parent f49123d commit 98530db
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 2 deletions.
3 changes: 3 additions & 0 deletions configs/ds_config_zero3.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
{
"checkpoint": {
"use_node_local_storage": true
},
"fp16": {
"enabled": "auto",
"loss_scale": 0,
Expand Down
4 changes: 2 additions & 2 deletions install.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/bin/bash

pip install -e .
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple

gpu_state="$(nvidia-smi --query-gpu=name --format=csv,noheader)"
if [[ "${gpu_state}" == *"A100"* || "${gpu_state}" == *"A40"* || "${gpu_state}" == *"A6000"* ]]; then
pip install flash-attn==2.0.2
pip install flash-attn==2.0.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
fi
148 changes: 148 additions & 0 deletions scripts/remove_redundant_checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
All rights reserved.
'''
# pylint: disable=too-many-locals
# pylint: disable=invalid-name
# pylint: disable=no-member
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments
# pylint: disable=self-assigning-variable
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
# pylint: disable=literal-comparison

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import pathlib
import shutil
import sys
import time

from datetime import datetime


def parse_argument(sys_argv):
"""Parses arguments from command line.
Args:
sys_argv: the list of arguments (strings) from command line.
Returns:
A struct whose member corresponds to the required (optional) variable.
For example,
```
args = parse_argument(['main.py' '--input', 'a.txt', '--num', '10'])
args.input # 'a.txt'
args.num # 10
```
"""
parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter
)

# Add arguments
parser.add_argument(
'--checkpoint_dir', type=str, required=True, default=None,
help='directory with structure {checkpoint_dir}/checkpoint-*/*'
)
parser.add_argument(
'--save_total_limit', type=int, default=11,
help='maximum number of existing checkpoints',
)
parser.add_argument(
'--save_frequency', type=int, default=1000,
help='checkpoint-K with K % {save_frequency} == 0 will be retained',
)
parser.add_argument(
'--sleep_time', type=int, default=-1,
help='how frequently check the directory, -1 means check only once',
)
parser.add_argument(
'--debug_mode', type=int, default=0,
help='when this value is 1, only simulation results will be printed',
)

# Parses from commandline
args = parser.parse_args(sys_argv[1:])

return args


def sub(list_a, list_b):
return list(set(list_a) - set(list_b))


def remove_redundant_checkpoints(args):
checkpoint_dir = pathlib.Path(args.checkpoint_dir)

# First scan for counting
checkpoint_id_list = []
for checkpoint in checkpoint_dir.iterdir():
if not checkpoint.is_dir():
continue

try:
checkpoint_id = int(checkpoint.name.lstrip('checkpoint-'))
except:
continue

checkpoint_id_list.append(checkpoint_id)

# Retains K % save_frequency == 0 checkpoints
retain_id_list = []
num_remain_slots = args.save_total_limit
for checkpoint_id in reversed(sorted(checkpoint_id_list)):
if num_remain_slots <= 0:
break
if checkpoint_id % args.save_frequency == 0:
retain_id_list.append(checkpoint_id)
num_remain_slots -= 1
continue

checkpoint_id_list = sub(checkpoint_id_list, retain_id_list)

# If there is still space, retain latest checkpoints
num_remain_slots = args.save_total_limit - len(retain_id_list)
if num_remain_slots > 0:
latest_id_list = sorted(checkpoint_id_list)[-num_remain_slots:]
retain_id_list += latest_id_list
checkpoint_id_list = sub(checkpoint_id_list, latest_id_list)

# Removes
remove_id_list = checkpoint_id_list
if args.debug_mode:
for checkpoint_id in sorted(remove_id_list):
checkpoint_path = checkpoint_dir / f'checkpoint-{checkpoint_id}'
print('-', checkpoint_path)

for checkpoint_id in sorted(retain_id_list):
checkpoint_path = checkpoint_dir / f'checkpoint-{checkpoint_id}'
print('+', checkpoint_path)
else:
for checkpoint_id in sorted(remove_id_list):
checkpoint_path = checkpoint_dir / f'checkpoint-{checkpoint_id}'
shutil.rmtree(checkpoint_path)


def main():
"""Removes redundant checkpoints"""
args = parse_argument(sys.argv)
print('#################################################')
print('args =', str(args))

if args.sleep_time < 0:
remove_redundant_checkpoints(args)
else:
while True:
remove_redundant_checkpoints(args)
time.sleep(args.sleep_time)


if __name__ == '__main__':
main()

0 comments on commit 98530db

Please sign in to comment.