Bugfix multi GPUs training #472

Open
wants to merge 9 commits into base: main
12 changes: 9 additions & 3 deletions library/train_util.py
@@ -1766,7 +1766,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
    if mem_eff_attn:
        replace_unet_cross_attn_to_memory_efficient()
    elif xformers:
-        replace_unet_cross_attn_to_xformers()
+        replace_unet_cross_attn_to_xformers(unet)


def replace_unet_cross_attn_to_memory_efficient():
@@ -1809,7 +1809,7 @@ def forward_flash_attn(self, x, context=None, mask=None):
    diffusers.models.attention.CrossAttention.forward = forward_flash_attn


-def replace_unet_cross_attn_to_xformers():
+def replace_unet_cross_attn_to_xformers(unet):
print("Replace CrossAttention.forward to use xformers")
try:
import xformers.ops
@@ -1849,7 +1849,13 @@ def forward_xformers(self, x, context=None, mask=None):
        out = self.to_out[1](out)
        return out

-    diffusers.models.attention.CrossAttention.forward = forward_xformers
+    print("diffusers version:", diffusers.__version__)
+    if diffusers.__version__ >= "0.11.0":
+        # let xformers decide which op is more suitable; see _dispatch_fwd in xformers.ops
+        unet.enable_xformers_memory_efficient_attention()
+    elif hasattr(diffusers.models.attention, "CrossAttention") and \
+            hasattr(diffusers.models.attention.CrossAttention, "forward"):
+        diffusers.models.attention.CrossAttention.forward = forward_xformers


# endregion
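The new branch gates on diffusers.__version__ with a plain string comparison, which works for the versions in play here but misorders some cases (as a string, "0.9.0" >= "0.11.0" is True). A minimal sketch of the same gate using parsed versions; the helper name is hypothetical and assumes the packaging module is available:

    # Sketch only: a version gate equivalent in intent to the string comparison above.
    from packaging import version

    import diffusers

    def supports_builtin_xformers() -> bool:
        # diffusers >= 0.11.0 can pick the best xformers kernel itself via
        # unet.enable_xformers_memory_efficient_attention().
        return version.parse(diffusers.__version__) >= version.parse("0.11.0")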
4 changes: 2 additions & 2 deletions networks/lora.py
@@ -605,9 +605,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh

class LoRANetwork(torch.nn.Module):
    NUM_OF_BLOCKS = 12  # number of up/down blocks, same as in the full model

+    import diffusers
    # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
-    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] if diffusers.__version__ < "0.15.0" else ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
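The class attribute is now chosen at import time from the installed diffusers version, presumably because diffusers 0.15 renamed its attention classes so that matching "Attention" would wrap more modules than intended. A minimal sketch of the same selection as a standalone helper (hypothetical name), again with parsed versions rather than a raw string comparison:

    from packaging import version

    import diffusers

    def unet_target_replace_modules():
        # Sketch only: pick the UNet target module names by diffusers version.
        if version.parse(diffusers.__version__) < version.parse("0.15.0"):
            return ["Transformer2DModel", "Attention"]
        return ["Transformer2DModel"]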
33 changes: 23 additions & 10 deletions train_network.py
@@ -61,7 +61,6 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche

    return logs

-
def train(args):
    session_id = random.randint(0, 2**32)
    training_started_at = time.time()
@@ -73,9 +72,15 @@ def train(args):
    use_user_config = args.dataset_config is not None

    if args.seed is None:
-        args.seed = random.randint(0, 2**32)
-    set_seed(args.seed)
+        import psutil
+        ppid = os.getppid()
+        parent_process = psutil.Process(ppid)
+        if len(parent_process.children()) > 1:
+            args.seed = ppid
+        else:
+            args.seed = random.randint(0, 2**32)

+    set_seed(args.seed)
    tokenizer = train_util.load_tokenizer(args)

    # prepare the dataset
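The new seed logic gives every data-parallel worker the same seed without any inter-process communication: workers spawned by the same launcher (for example `accelerate launch`) share a parent PID, so that PID doubles as a common seed, while a single-process run falls back to a random one. A minimal standalone sketch of the idea, assuming psutil is installed (the function name is hypothetical):

    import os
    import random

    import psutil

    def pick_shared_seed() -> int:
        # Sketch only: the parent PID is identical across launcher-spawned workers.
        parent = psutil.Process(os.getppid())
        if len(parent.children()) > 1:
            return parent.pid
        # Single-process run: an ordinary random seed is fine.
        return random.randint(0, 2**32)

Calling set_seed() with this value keeps noise and timestep sampling aligned across ranks, which the per-process batch slicing further down appears to rely on.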
@@ -138,8 +143,8 @@ def train(args):
    # prepare the accelerator
    print("prepare accelerator")
    accelerator, unwrap_model = train_util.prepare_accelerator(args)
-    is_main_process = accelerator.is_main_process

+    is_main_process = accelerator.is_main_process
    # prepare dtypes matching the mixed precision setting and cast as needed
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

@@ -265,19 +270,19 @@ def train(args):

    # the accelerator apparently takes care of the rest for us
    if train_unet and train_text_encoder:
-        unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, network, optimizer, train_dataloader_, lr_scheduler = accelerator.prepare(
            unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
        )
    elif train_unet:
-        unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, network, optimizer, train_dataloader_, lr_scheduler = accelerator.prepare(
            unet, network, optimizer, train_dataloader, lr_scheduler
        )
    elif train_text_encoder:
-        text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, network, optimizer, train_dataloader_, lr_scheduler = accelerator.prepare(
            text_encoder, network, optimizer, train_dataloader, lr_scheduler
        )
    else:
-        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
+        network, optimizer, train_dataloader_, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)

    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
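The only change in this hunk is the name bound on the left-hand side: the objects returned by accelerator.prepare() go to train_dataloader_ (and friends), while the original, un-sharded train_dataloader stays available so the loop below can iterate full batches on every process and carve out its own share. A small self-contained sketch of that pattern with toy objects (all names and shapes are illustrative):

    import torch
    from accelerate import Accelerator
    from torch.utils.data import DataLoader, TensorDataset

    accelerator = Accelerator()
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
    train_dataloader = DataLoader(dataset, batch_size=8)

    # The wrapped copies get new names; the bare train_dataloader above is kept
    # un-sharded so each process can slice full batches itself.
    model_, optimizer_, train_dataloader_ = accelerator.prepare(model, optimizer, train_dataloader)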
@@ -569,6 +574,9 @@ def remove_model(old_ckpt_name):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)

mini_batch_size = int(args.train_batch_size) / accelerator.num_processes
mini_batch_offset = int(accelerator.process_index) * mini_batch_size

# training loop
for epoch in range(num_train_epochs):
if is_main_process:
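The two new variables split the global batch across processes with simple arithmetic: with a train_batch_size of 8 on 2 GPUs, each rank handles 4 samples, rank 0 starting at offset 0 and rank 1 at offset 4. A tiny sketch with assumed example numbers (integer division avoids the fractional indices that the patch later truncates with int(...)):

    # Assumed example values: train_batch_size=8, two processes.
    train_batch_size = 8
    num_processes = 2

    mini_batch_size = train_batch_size // num_processes               # -> 4
    offsets = [rank * mini_batch_size for rank in range(num_processes)]
    assert mini_batch_size == 4 and offsets == [0, 4]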
@@ -581,6 +589,11 @@ def remove_model(old_ckpt_name):

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step
+            # cut mini batch
+            for k in batch.keys():
+                if batch[k] is None: continue
+                batch[k] = batch[k][int(mini_batch_offset):int(mini_batch_offset+mini_batch_size)]
+
            with accelerator.accumulate(network):
                # on_step_start(text_encoder, unet)

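Because every process iterates the same full-sized batch, each one keeps only its own slice of every entry in the batch dict, passing None values through untouched. A minimal sketch of that slicing on a toy batch (keys, shapes, and the helper name are illustrative):

    import torch

    def slice_batch(batch, process_index, mini_batch_size):
        # Sketch only: keep this process's share of each batch entry.
        start = process_index * mini_batch_size
        end = start + mini_batch_size
        return {k: (v if v is None else v[start:end]) for k, v in batch.items()}

    batch = {"latents": torch.randn(8, 4, 64, 64), "loss_weights": torch.ones(8), "flags": None}
    rank1 = slice_batch(batch, process_index=1, mini_batch_size=4)
    assert rank1["latents"].shape[0] == 4 and rank1["flags"] is None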
@@ -633,7 +646,7 @@ def remove_model(old_ckpt_name):
                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

-                loss_weights = batch["loss_weights"]  # per-sample weight
+                loss_weights = batch["loss_weights"].to(accelerator.device)  # per-sample weight
                loss = loss * loss_weights

                if args.min_snr_gamma:
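Since the batch now comes from the un-prepared DataLoader, its tensors are still on the CPU, so the per-sample weights must be moved to the accelerator's device before they can be multiplied into the per-sample loss. A toy illustration of the mismatch being fixed (device choice and shapes are assumptions):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    per_sample_loss = torch.rand(4, device=device)   # lives where the model runs
    loss_weights = torch.ones(4)                     # came from the CPU-side DataLoader
    weighted = per_sample_loss * loss_weights.to(device)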
@@ -733,7 +746,7 @@ def remove_model(old_ckpt_name):
    if is_main_process:
        ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
        save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)

        print("model saved.")

