diff --git a/library/train_util.py b/library/train_util.py
index 8c6e34371..7e104adad 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1766,7 +1766,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
     if mem_eff_attn:
         replace_unet_cross_attn_to_memory_efficient()
     elif xformers:
-        replace_unet_cross_attn_to_xformers()
+        replace_unet_cross_attn_to_xformers(unet)
 
 
 def replace_unet_cross_attn_to_memory_efficient():
@@ -1809,7 +1809,7 @@ def forward_flash_attn(self, x, context=None, mask=None):
     diffusers.models.attention.CrossAttention.forward = forward_flash_attn
 
 
-def replace_unet_cross_attn_to_xformers():
+def replace_unet_cross_attn_to_xformers(unet):
     print("Replace CrossAttention.forward to use xformers")
     try:
         import xformers.ops
@@ -1849,7 +1849,13 @@ def forward_xformers(self, x, context=None, mask=None):
         out = self.to_out[1](out)
         return out
 
-    diffusers.models.attention.CrossAttention.forward = forward_xformers
+    print("diffusers version:", diffusers.__version__)
+    if diffusers.__version__ >= "0.11.0":
+        # let xformers decide which op is most suitable; see _dispatch_fwd in xformers.ops
+        unet.enable_xformers_memory_efficient_attention()
+    elif hasattr(diffusers.models.attention, "CrossAttention") and \
+            hasattr(diffusers.models.attention.CrossAttention, "forward"):
+        diffusers.models.attention.CrossAttention.forward = forward_xformers
 
 
 # endregion
diff --git a/networks/lora.py b/networks/lora.py
index 353b1f5ac..5e4c2e080 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -605,9 +605,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
 
 class LoRANetwork(torch.nn.Module):
     NUM_OF_BLOCKS = 12  # フルモデル相当でのup,downの層の数
-
+    import diffusers
     # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
-    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] if diffusers.__version__ < "0.15.0" else ["Transformer2DModel"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
diff --git a/train_network.py b/train_network.py
index 5c4d5ad19..28ab2c1c5 100644
--- a/train_network.py
+++ b/train_network.py
@@ -61,7 +61,6 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
     return logs
 
 
-
 def train(args):
     session_id = random.randint(0, 2**32)
     training_started_at = time.time()
@@ -73,9 +72,15 @@ def train(args):
     use_user_config = args.dataset_config is not None
 
     if args.seed is None:
-        args.seed = random.randint(0, 2**32)
-    set_seed(args.seed)
+        import psutil
+        ppid = os.getppid()
+        parent_process = psutil.Process(ppid)
+        if len(parent_process.children()) > 1:
+            args.seed = ppid
+        else:
+            args.seed = random.randint(0, 2**32)
+    set_seed(args.seed)
 
     tokenizer = train_util.load_tokenizer(args)
 
     # データセットを準備する
@@ -138,8 +143,8 @@
     # acceleratorを準備する
     print("prepare accelerator")
     accelerator, unwrap_model = train_util.prepare_accelerator(args)
-    is_main_process = accelerator.is_main_process
 
+    is_main_process = accelerator.is_main_process
     # mixed precisionに対応した型を用意しておき適宜castする
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
@@ -265,19 +270,19 @@
 
     # acceleratorがなんかよろしくやってくれるらしい
     if train_unet and train_text_encoder:
-        unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, network, optimizer, train_dataloader_, lr_scheduler = accelerator.prepare(
             unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
         )
     elif train_unet:
-        unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, network, optimizer, train_dataloader_, lr_scheduler = accelerator.prepare(
             unet, network, optimizer, train_dataloader, lr_scheduler
         )
     elif train_text_encoder:
-        text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, network, optimizer, train_dataloader_, lr_scheduler = accelerator.prepare(
             text_encoder, network, optimizer, train_dataloader, lr_scheduler
         )
     else:
-        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
+        network, optimizer, train_dataloader_, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
 
     unet.requires_grad_(False)
     unet.to(accelerator.device, dtype=weight_dtype)
@@ -569,6 +574,9 @@ def remove_model(old_ckpt_name):
            print(f"removing old checkpoint: {old_ckpt_file}")
            os.remove(old_ckpt_file)
 
+    mini_batch_size = int(args.train_batch_size) / accelerator.num_processes
+    mini_batch_offset = int(accelerator.process_index) * mini_batch_size
+
     # training loop
     for epoch in range(num_train_epochs):
         if is_main_process:
@@ -581,6 +589,11 @@
 
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
+            # cut mini batch
+            for k in batch.keys():
+                if batch[k] is None: continue
+                batch[k] = batch[k][int(mini_batch_offset):int(mini_batch_offset+mini_batch_size)]
+
             with accelerator.accumulate(network):
                 # on_step_start(text_encoder, unet)
 
@@ -633,7 +646,7 @@
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                 loss = loss.mean([1, 2, 3])
 
-                loss_weights = batch["loss_weights"]  # 各sampleごとのweight
+                loss_weights = batch["loss_weights"].to(accelerator.device)  # 各sampleごとのweight
                 loss = loss * loss_weights
 
                 if args.min_snr_gamma:
@@ -733,7 +746,7 @@
     if is_main_process:
         ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
         save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
-
+
         print("model saved.")
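Note on the version checks above: diffusers.__version__ is compared with plain string operators, which happens to work for the versions this patch targets but is not a reliable version comparison (for example "0.9.0" >= "0.11.0" is True as a string). Below is a minimal version-safe sketch, assuming the packaging package is available (it normally ships alongside diffusers); the helper name is illustrative and not part of the patch.

import diffusers
from packaging import version


def diffusers_at_least(minimum: str) -> bool:
    # compare parsed versions rather than raw strings, so "0.9.0" < "0.11.0" holds
    return version.parse(diffusers.__version__) >= version.parse(minimum)


# the two checks in this patch, expressed with the helper
use_builtin_xformers_attention = diffusers_at_least("0.11.0")
include_attention_in_lora_targets = not diffusers_at_least("0.15.0")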
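Note on the seed change in train_network.py: the patch appears to rely on all workers spawned by a single launcher (for example accelerate launch with more than one process) sharing the same parent PID, so seeding from os.getppid() gives every rank an identical seed without any inter-process communication, while a single-process run still falls back to a random seed. It also makes psutil a runtime dependency of the script. A standalone sketch of that idea, with an illustrative function name:

import os
import random

import psutil


def pick_shared_seed(explicit_seed=None) -> int:
    # an explicitly configured seed always wins
    if explicit_seed is not None:
        return explicit_seed
    parent = psutil.Process(os.getppid())
    if len(parent.children()) > 1:
        # several sibling workers under one launcher -> every rank derives the same seed
        return os.getppid()
    # single process -> any random seed is fine
    return random.randint(0, 2**32)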
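Note on the manual mini-batch split: the dataloader returned by accelerator.prepare is bound to train_dataloader_ and is unused in this diff, so each rank keeps iterating the original, unsharded train_dataloader, sees the full global batch, and cuts out its own slice by process index. Since mini_batch_size is computed with /, it is a float, which is why the slice bounds are re-cast with int() in the loop. A sketch of the same idea with integer arithmetic, assuming train_batch_size is the global batch size and divides evenly by the number of processes (names are illustrative):

def slice_batch_for_rank(batch: dict, process_index: int, num_processes: int, global_batch_size: int) -> dict:
    per_rank = global_batch_size // num_processes  # integer mini-batch size per rank
    start = process_index * per_rank               # this rank's offset into the global batch
    end = start + per_rank
    sliced = {}
    for key, value in batch.items():
        # keep None entries untouched; slice everything else along the batch dimension
        sliced[key] = value if value is None else value[start:end]
    return sliced

As in the patch, only None values are special-cased; any batch entry that is not indexable per sample would need its own handling.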