diff --git a/library/train_util.py b/library/train_util.py
index 75176e130..784860f92 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2290,6 +2290,8 @@ def __call__(self, input_ids, attention_mask):
   with torch.no_grad():
     with accelerator.autocast():
       for i, prompt in enumerate(prompts):
+        if not accelerator.is_main_process:
+          continue
         prompt = prompt.strip()
         if len(prompt) == 0 or prompt[0] == '#':
           continue
@@ -2346,7 +2348,13 @@ def __call__(self, input_ids, attention_mask):
           prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
           if negative_prompt is not None:
             negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
-
+
+        print(f"prompt: {prompt}")
+        print(f"negative_prompt: {negative_prompt}")
+        print(f"height: {height}")
+        print(f"width: {width}")
+        print(f"sample_steps: {sample_steps}")
+        print(f"scale: {scale}")
         image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
 
         ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
diff --git a/train_network.py b/train_network.py
index 7aee65146..f45281041 100644
--- a/train_network.py
+++ b/train_network.py
@@ -106,6 +106,7 @@ def train(args):
   # acceleratorを準備する
   print("prepare accelerator")
   accelerator, unwrap_model = train_util.prepare_accelerator(args)
+  is_main_process = accelerator.is_main_process
 
   # mixed precisionに対応した型を用意しておき適宜castする
   weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -175,12 +176,13 @@ def train(args):
 
   # 学習ステップ数を計算する
   if args.max_train_epochs is not None:
-    args.max_train_steps = args.max_train_epochs * len(train_dataloader)
-    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+    args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
+    if is_main_process:
+      print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
   lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
                                               num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
@@ -251,15 +253,17 @@ def train(args):
   # 学習する
   # TODO: find a way to handle total batch size when there are multiple datasets
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-  print("running training / 学習開始")
-  print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-  print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-  print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-  print(f"  num epochs / epoch数: {num_train_epochs}")
-  print(f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
-  # print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-  print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-  print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
+
+  if is_main_process:
+    print("running training / 学習開始")
+    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+    print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    print(f"  num epochs / epoch数: {num_train_epochs}")
+    print(f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
+    # print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
   # TODO refactor metadata creation and move to util
   metadata = {
@@ -461,7 +465,8 @@ def train(args):
   loss_list = []
   loss_total = 0.0
   for epoch in range(num_train_epochs):
-    print(f"epoch {epoch+1}/{num_train_epochs}")
+    if is_main_process:
+      print(f"epoch {epoch+1}/{num_train_epochs}")
     train_dataset_group.set_current_epoch(epoch + 1)
 
     metadata["ss_epoch"] = str(epoch+1)
@@ -573,9 +578,10 @@ def remove_old_func(old_epoch_no):
         print(f"removing old checkpoint: {old_ckpt_file}")
         os.remove(old_ckpt_file)
 
-    saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
-    if saving and args.save_state:
-      train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
+    if is_main_process:
+      saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
+      if saving and args.save_state:
+        train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
     train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
 
@@ -584,7 +590,6 @@ def remove_old_func(old_epoch_no):
 
   metadata["ss_epoch"] = str(num_train_epochs)
   metadata["ss_training_finished_at"] = str(time.time())
 
-  is_main_process = accelerator.is_main_process
   if is_main_process:
     network = unwrap_model(network)
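
Note on the step arithmetic in the second train_network.py hunk: accelerator.prepare() shards the dataloader, so each process iterates over roughly 1/num_processes of the batches per epoch, which makes max_train_steps a per-process optimizer-step count; meanwhile Accelerate's scheduler wrapper, by default, advances the wrapped scheduler num_processes times per step() call, which is why num_training_steps is multiplied back up by accelerator.num_processes. A minimal standalone sketch of that accounting, assuming default Accelerate scheduler behavior (the helper name distributed_step_budget is illustrative, not part of the patch):

import math

def distributed_step_budget(num_batches_per_epoch: int,
                            max_train_epochs: int,
                            num_processes: int,
                            gradient_accumulation_steps: int):
    # After the dataloader is sharded, each process sees about
    # ceil(num_batches / num_processes) batches per epoch.
    steps_per_epoch = math.ceil(num_batches_per_epoch / num_processes)

    # Per-process optimizer steps, mirroring the max_train_steps override.
    max_train_steps = max_train_epochs * steps_per_epoch

    # The scheduler horizon is scaled back up by num_processes (and by
    # gradient accumulation) so the LR decay curve stays aligned with
    # single-GPU training.
    scheduler_steps = max_train_steps * num_processes * gradient_accumulation_steps
    return max_train_steps, scheduler_steps

# Example: 1000 batches/epoch, 10 epochs, 2 GPUs, no accumulation
# -> (5000, 10000): 5000 optimizer steps per process, 10000 scheduler ticks
print(distributed_step_budget(1000, 10, 2, 1))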