
fix for multi gpu training
ddPn08 committed Mar 2, 2023
1 parent 08fcc7b commit ea57508
Showing 2 changed files with 31 additions and 18 deletions.
10 changes: 9 additions & 1 deletion library/train_util.py
@@ -2290,6 +2290,8 @@ def __call__(self, input_ids, attention_mask):
  with torch.no_grad():
    with accelerator.autocast():
      for i, prompt in enumerate(prompts):
+       if not accelerator.is_main_process:
+         continue
        prompt = prompt.strip()
        if len(prompt) == 0 or prompt[0] == '#':
          continue
@@ -2346,7 +2348,13 @@ def __call__(self, input_ids, attention_mask):
          prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
          if negative_prompt is not None:
            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

+       print(f"prompt: {prompt}")
+       print(f"negative_prompt: {negative_prompt}")
+       print(f"height: {height}")
+       print(f"width: {width}")
+       print(f"sample_steps: {sample_steps}")
+       print(f"scale: {scale}")
        image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]

        ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
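Note on the library/train_util.py change: sample-image generation is now skipped on every process except the main one, so in a multi-GPU run only rank 0 drives the diffusion pipeline while the other ranks continue past the prompt loop. A minimal sketch of the same guard written directly against the Hugging Face Accelerate API (the prompt list and generate_fn callable are illustrative placeholders, not names from this repository):

from accelerate import Accelerator

accelerator = Accelerator()

def sample_images_on_main_process(prompts, generate_fn):
    # Only rank 0 renders sample images; the other ranks skip the loop entirely.
    if accelerator.is_main_process:
        for prompt in prompts:
            generate_fn(prompt)
    # Optionally re-synchronize so all ranks resume training together.
    accelerator.wait_for_everyone()

The commit gets the same effect inside the existing loop with "if not accelerator.is_main_process: continue", which leaves the surrounding torch.no_grad() / autocast() blocks untouched; the print() calls added in the second hunk simply log the sampling parameters.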
39 changes: 22 additions & 17 deletions train_network.py
@@ -106,6 +106,7 @@ def train(args):
  # prepare the accelerator
  print("prepare accelerator")
  accelerator, unwrap_model = train_util.prepare_accelerator(args)
+ is_main_process = accelerator.is_main_process

  # prepare dtypes for mixed precision and cast as appropriate
  weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -175,12 +176,13 @@ def train(args):

  # calculate the number of training steps
  if args.max_train_epochs is not None:
-   args.max_train_steps = args.max_train_epochs * len(train_dataloader)
-   print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+   args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
+   if is_main_process:
+     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

  # prepare the lr scheduler
  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-                                             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                             num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

  # experimental feature: perform fp16 training including gradients (cast the whole model to fp16)
@@ -251,15 +253,17 @@ def train(args):
  # start training
  # TODO: find a way to handle total batch size when there are multiple datasets
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
- print("running training / 学習開始")
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
- print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
- print(f" num epochs / epoch数: {num_train_epochs}")
- print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
- # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
- print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

+ if is_main_process:
+   print("running training / 学習開始")
+   print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+   print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+   print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+   print(f" num epochs / epoch数: {num_train_epochs}")
+   print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
+   # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+   print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+   print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

  # TODO refactor metadata creation and move to util
  metadata = {
@@ -461,7 +465,8 @@ def train(args):
  loss_list = []
  loss_total = 0.0
  for epoch in range(num_train_epochs):
-   print(f"epoch {epoch+1}/{num_train_epochs}")
+   if is_main_process:
+     print(f"epoch {epoch+1}/{num_train_epochs}")
    train_dataset_group.set_current_epoch(epoch + 1)

    metadata["ss_epoch"] = str(epoch+1)
@@ -573,9 +578,10 @@ def remove_old_func(old_epoch_no):
        print(f"removing old checkpoint: {old_ckpt_file}")
        os.remove(old_ckpt_file)

-   saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
-   if saving and args.save_state:
-     train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
+   if is_main_process:
+     saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
+     if saving and args.save_state:
+       train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

    train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

@@ -584,7 +590,6 @@ def remove_old_func(old_epoch_no):
  metadata["ss_epoch"] = str(num_train_epochs)
  metadata["ss_training_finished_at"] = str(time.time())

- is_main_process = accelerator.is_main_process
  if is_main_process:
    network = unwrap_model(network)

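Note on the train_network.py changes: the apparent intent of dividing len(train_dataloader) by accelerator.num_processes is that, under data-parallel training, each optimization step consumes one batch on every process, so the number of steps needed to cover the requested epochs shrinks by the process count; the scheduler horizon is scaled back up by num_processes, presumably because a scheduler passed through accelerator.prepare() is stepped once per process. The remaining hunks define is_main_process right after the accelerator is created and use it to restrict progress printing and end-of-epoch checkpoint/state saving to rank 0. A rough worked example of the step arithmetic, with all figures hypothetical (1,000 batches per epoch, 2 epochs, 2 GPUs, no gradient accumulation):

import math

# Hypothetical configuration, for illustration only.
num_processes = 2              # accelerator.num_processes
batches_per_epoch = 1000       # len(train_dataloader)
max_train_epochs = 2
gradient_accumulation_steps = 1

# Per-process optimization steps, as computed after this commit.
max_train_steps = max_train_epochs * math.ceil(batches_per_epoch / num_processes)
print(max_train_steps)  # 1000 (was 2000 before the change)

# Horizon handed to the LR scheduler, scaled back up by the process count.
num_training_steps = max_train_steps * num_processes * gradient_accumulation_steps
print(num_training_steps)  # 2000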
