-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
1489 lines (1273 loc) · 67.5 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Install bitsandbytes:
# `nvcc --version` to get CUDA version.
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
# Example Usage:
# Single GPU: torchrun --nproc_per_node=1 trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=1 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Multiple GPUs: torchrun --nproc_per_node=N trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
import argparse
import copy
import socket
import torch
import torchvision
import transformers
import diffusers
import os
import glob
import random
import tqdm
import resource
import psutil
import pynvml
import wandb
import gc
import time
import itertools
import numpy as np
import PIL
import json
import re
import traceback
import gc
import shutil
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.transforms import functional as visionF
try:
pynvml.nvmlInit()
except pynvml.nvml.NVMLError_LibraryNotFound:
pynvml = None
from typing import Iterable, Optional
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, DDIMScheduler, StableDiffusionXLPipeline, EulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from PIL import Image, ImageOps
from PIL.Image import Image as Img
from collections import defaultdict
from typing import Dict, List, Generator, Tuple
from scipy.interpolate import interp1d
# Enable TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True

# defaults should be good for everyone
# TODO: add custom VAE support. should be simple with diffusers


def bool_t(x):
    """Parse a command-line boolean: 'true', 'yes' and '1' (any case) mean True."""
    return x.lower() in ['true', 'yes', '1']


parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
parser.add_argument('--resume', type=str, default=None, help='The path to the checkpoint to resume from. If not specified, will create a new run.')
parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
parser.add_argument('--num_buckets', type=int, default=20, help='The number of buckets.')
parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
parser.add_argument('--bucket_side_max', type=int, default=1536, help='The maximum side length of a bucket.')
parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train for')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
parser.add_argument('--gradient_accumulation', type=int, default=2, help='gradient_accumulation size default 2')
parser.add_argument('--use_ema', type=bool_t, default='False', help='Use EMA for finetuning')
# NOTE: argparse %-formats help strings, so a literal percent sign must be written as '%%'.
parser.add_argument('--ucg', type=float, default=0.1, help='Percentage chance of dropping out the text condition per batch. Ranges from 0.0 to 1.0 where 1.0 means 100%% text condition dropout.') # 10% dropout probability
parser.add_argument('--gradient_checkpointing', dest='gradient_checkpointing', type=bool_t, default='False', help='Enable gradient checkpointing')
parser.add_argument('--use_8bit_adam', dest='use_8bit_adam', type=bool_t, default='False', help='Use 8-bit Adam optimizer')
parser.add_argument('--adam_beta1', type=float, default=0.9, help='Adam beta1')
parser.add_argument('--adam_beta2', type=float, default=0.999, help='Adam beta2')
parser.add_argument('--adam_weight_decay', type=float, default=0, help='Adam weight decay')
parser.add_argument('--adam_epsilon', type=float, default=1e-08, help='Adam epsilon')
parser.add_argument('--lr_scheduler', type=str, default='cosine', help='Learning rate scheduler [`cosine`, `linear`, `constant`]')
parser.add_argument('--lr_scheduler_warmup', type=float, default=0.001, help='Learning rate scheduler warmup steps. This is a percentage of the total number of steps in the training run. 0.1 means 10 percent of the total number of steps.')
parser.add_argument('--seed', type=int, default=42, help='Seed for random number generator, this is to be used for reproduceability purposes.')
parser.add_argument('--output_path', type=str, default='./output', help='Root path for all outputs.')
parser.add_argument('--save_steps', type=int, default=500, help='Number of steps to save checkpoints at.')
parser.add_argument('--resolution', type=int, default=512, help='Image resolution to train against. Lower res images will be scaled up to this resolution and higher res images will be scaled down.')
parser.add_argument('--shuffle', dest='shuffle', type=bool_t, default='True', help='Shuffle dataset')
parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
parser.add_argument('--fp16', dest='fp16', type=bool_t, default='False', help='Train in mixed precision')
parser.add_argument('--image_log_steps', type=int, default=500, help='Number of steps to log images at.')
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
parser.add_argument('--image_log_scheduler', type=str, default="EulerDiscreteScheduler", help='Name of the scheduler to use to log images.')
parser.add_argument('--clip_penultimate', type=bool_t, default='True', help='Use penultimate CLIP layer for text embedding')
parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
parser.add_argument('--wandb', dest='enablewandb', type=bool_t, default='False', help='Enable WeightsAndBiases Reporting')
parser.add_argument('--inference', dest='enableinference', type=bool_t, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)')
parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
parser.add_argument('--extended_mode_chunks', type=int, default=3, help='Enables extended mode for tokenization with given amount of maximum chunks. Values < 2 disable.')
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
parser.add_argument(
    "--snr_gamma",
    type=float,
    default=5.0,
    help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
    "More details here: https://arxiv.org/abs/2303.09556.",
)
# Parsed once at import time; the rest of the module reads this global.
args = parser.parse_args()
def setup():
    """Try to initialise the NCCL process group from environment variables.

    Prints whether distributed training is enabled.

    :return: True when the process group was initialised, False when
        initialisation raised any exception.
    """
    try:
        torch.distributed.init_process_group("nccl", init_method="env://")
    except Exception:
        # Best effort: any failure means we run as a single process.
        print('distributed training is DISABLED')
        return False
    print('distributed training is ENABLED')
    return True
def cleanup():
    """Destroy the default process group if one was initialised.

    Does nothing when no process group exists (for example when ``setup()``
    returned False), so it is safe to call unconditionally at shutdown.
    """
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
def get_rank() -> int:
    """Return this process's rank, or 0 when no process group is initialised."""
    dist = torch.distributed
    return dist.get_rank() if dist.is_initialized() else 0
def get_world_size() -> int:
    """Return the number of processes, or 1 when no process group is initialised."""
    dist = torch.distributed
    return dist.get_world_size() if dist.is_initialized() else 1
def get_gpu_ram() -> str:
    """
    Returns memory usage statistics for the CPU, GPU, and Torch.

    The GPU section is omitted when NVML is unavailable (the module rebinds
    ``pynvml`` to None when the NVML library cannot be loaded), and both the
    GPU and Torch sections are omitted when no CUDA device can be queried.

    :return: a single-line, human-readable summary string.
    """
    gpu_str = ""
    torch_str = ""
    try:
        cudadev = torch.cuda.current_device()
        if pynvml is not None:
            nvml_device = pynvml.nvmlDeviceGetHandleByIndex(cudadev)
            gpu_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_device)
            gpu_total = int(gpu_info.total / 1E6)
            gpu_free = int(gpu_info.free / 1E6)
            gpu_used = int(gpu_info.used / 1E6)
            gpu_str = f"GPU: (U: {gpu_used:,}mb F: {gpu_free:,}mb " \
                      f"T: {gpu_total:,}mb) "
        torch_reserved_gpu = int(torch.cuda.memory.memory_reserved() / 1E6)
        torch_reserved_max = int(torch.cuda.memory.max_memory_reserved() / 1E6)
        torch_used_gpu = int(torch.cuda.memory_allocated() / 1E6)
        torch_max_used_gpu = int(torch.cuda.max_memory_allocated() / 1E6)
        torch_str = f"TORCH: (R: {torch_reserved_gpu:,}mb/" \
                    f"{torch_reserved_max:,}mb, " \
                    f"A: {torch_used_gpu:,}mb/{torch_max_used_gpu:,}mb)"
    except (AssertionError, RuntimeError):
        # torch raises one of these when CUDA is not compiled in or no
        # driver/device is present; report CPU statistics only.
        pass
    # Peak resident set size of this process plus its waited-for children.
    cpu_maxrss = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1E3 +
                     resource.getrusage(
                         resource.RUSAGE_CHILDREN).ru_maxrss / 1E3)
    cpu_vmem = psutil.virtual_memory()
    cpu_free = int(cpu_vmem.free / 1E6)
    return f"CPU: (maxrss: {cpu_maxrss:,}mb F: {cpu_free:,}mb) " \
           f"{gpu_str}" \
           f"{torch_str}"
def _sort_by_ratio(bucket: tuple) -> float:
return bucket[0] / bucket[1]
def _sort_by_area(bucket: tuple) -> float:
return bucket[0] * bucket[1]
class Validation():
    """Selects an image validation strategy and exposes it as ``self.validate``.

    ``validate(fp)`` takes an image path and returns True when the file should
    be kept in the dataset.
    """

    def __init__(self, is_skipped: bool, is_extended: bool) -> None:
        """
        :param is_skipped: when True every file is accepted without inspection.
        :param is_extended: when True (and not skipped) images are fully
            decoded so truncated or corrupt files are caught.
        """
        if is_skipped:
            self.validate = self.__no_op
            print("Validation: Skipped")
        elif is_extended:
            self.validate = self.__extended_validate
            print("Validation: Extended")
        else:
            self.validate = self.__validate
            print("Validation: Standard")

    def __validate(self, fp: str) -> bool:
        """Accept ``fp`` when it opens as an image and a sibling ``.txt`` exists."""
        try:
            with Image.open(fp) as img:
                stem, _ = os.path.splitext(fp)
                return img is not None and os.path.exists(stem + '.txt')
        except Exception:
            print(f'WARNING: Image cannot be opened: {fp}')
            return False

    def __extended_validate(self, fp: str) -> bool:
        """Accept ``fp`` only when the image data can be fully loaded."""
        try:
            with Image.open(fp) as img:
                img.load()
            return True
        except OSError as error:
            if 'truncated' in str(error):
                print(f'WARNING: Image truncated: {error}')
            else:
                print(f'WARNING: Image cannot be opened: {error}')
            return False
        except Exception as error:
            # Bind the exception here too; the message below needs it.
            print(f'WARNING: Image cannot be opened: {error}')
            return False

    def __no_op(self, fp: str) -> bool:
        """Accept every file."""
        return True
class Resize():
    """Selects an image resize strategy and exposes it as ``self.resize``.

    ``resize(image_path, w, h)`` returns a PIL image. Depending on the
    constructor flags it either opens the file unchanged, fits it to
    ``(w, h)``, or fits it and additionally writes the result plus its caption
    into a ``<dataset>_cropped`` directory next to ``args.dataset``.
    """

    def __init__(self, is_resizing: bool, is_not_migrating: bool) -> None:
        """
        :param is_resizing: when False images are returned as opened.
        :param is_not_migrating: when True resized images are not written to disk.
        """
        if not is_resizing:
            self.resize = self.__no_op
            return

        if is_not_migrating:
            self.resize = self.__no_migration
            return

        self.resize = self.__migration
        dataset_path = os.path.split(args.dataset)
        self.__directory = os.path.join(
            dataset_path[0],
            f'{dataset_path[1]}_cropped'
        )
        os.makedirs(self.__directory, exist_ok=True)
        print(f"Resizing: Performing migration to '{self.__directory}'.")

    @staticmethod
    def __fit(image_path: str, w: int, h: int) -> Img:
        """Open ``image_path`` and return a centre-cropped LANCZOS fit to ``(w, h)``."""
        # ImageOps.fit returns a new image, so the source file can be closed.
        with Image.open(image_path) as source:
            return ImageOps.fit(
                source,
                (w, h),
                bleed=0.0,
                centering=(0.5, 0.5),
                method=Image.Resampling.LANCZOS
            )

    def __no_migration(self, image_path: str, w: int, h: int) -> Img:
        """Return the image fitted to ``(w, h)`` without writing anything to disk."""
        return self.__fit(image_path, w, h)

    def __migration(self, image_path: str, w: int, h: int) -> Img:
        """Fit the image, save it as JPEG with its caption in the migration directory."""
        # Base name with the last extension removed.
        filename = re.sub(r'\.[^/.]+$', '', os.path.split(image_path)[1])
        image = self.__fit(image_path, w, h)
        # JPEG cannot store alpha/palette modes; convert only the saved copy.
        to_save = image if image.mode in ('L', 'RGB', 'CMYK') else image.convert('RGB')
        to_save.save(
            os.path.join(f'{self.__directory}', f'{filename}.jpg'),
            optimize=True
        )
        caption_target = os.path.join(self.__directory, f'{filename}.txt')
        try:
            # The caption lives beside the image (same stem, .txt extension).
            shutil.copy(
                os.path.join(os.path.dirname(image_path), f'{filename}.txt'),
                caption_target,
                follow_symlinks=False
            )
        except FileNotFoundError:
            # No caption for this image: create an empty one.
            with open(caption_target, 'w', encoding='UTF-8'):
                pass
        return image

    def __no_op(self, image_path: str, w: int, h: int) -> Img:
        """Return the image as opened; ``w`` and ``h`` are ignored."""
        return Image.open(image_path)
class ImageStore:
    """Index of the dataset's image files and their sidecar ``.txt`` captions.

    Files are discovered by globbing each comma-separated directory, filtered
    through the configured ``Validation`` strategy, and then addressed by
    their position in ``self.image_files``.
    """

    def __init__(self, data_dirs: str) -> None:
        """
        :param data_dirs: one or more dataset directories separated by commas.
        """
        self.data_dirs = data_dirs.split(',')
        for data_dir in self.data_dirs:
            print('include', data_dir)

        exts = ['jpg', 'jpeg', 'png', 'bmp', 'webp']
        max_files_per_pattern = 100000000
        found = set()
        for data_dir in self.data_dirs:
            print('listing files in', data_dir)
            # Files directly inside the directory, then files matched by the
            # '/**/' pattern (glob is called without recursive=True).
            for pattern in ('/*.', '/**/*.'):
                for e in exts:
                    found.update(glob.glob(f'{data_dir}' + pattern + e)[:max_files_per_pattern])
        # Sorted so that indices into image_files are stable between runs.
        self.image_files = sorted(found)

        self.validator = Validation(
            args.skip_validation,
            args.extended_validation
        ).validate
        self.resizer = Resize(args.resize, args.no_migration).resize
        print(' -- before validation we have', len(self.image_files), 'images')
        self.image_files = [x for x in self.image_files if self.validator(x)]
        print(' -- after validation we have', len(self.image_files), 'images')
        # Maps image index -> caption text, filled lazily by get_caption().
        self.caption_cache = {}

    def __len__(self) -> int:
        return len(self.image_files)

    # iterator returns images as PIL images and their index in the store
    def entries_iterator(self) -> Generator[Tuple[Img, int], None, None]:
        for index, path in enumerate(self.image_files):
            yield Image.open(path), index

    def caption_iterator(self) -> Generator[Tuple[str, int, str], None, None]:
        """Yield ``(caption, index, image_path)``; caption is '' when no .txt exists."""
        for index, path in enumerate(self.image_files):
            stem, _ = os.path.splitext(path)
            if os.path.exists(stem + '.txt'):
                with open(stem + '.txt', 'r', encoding='UTF-8') as fp:
                    txt = fp.read()
            else:
                txt = ''
            yield txt, index, path

    # get image by index
    def get_image(self, ref: Tuple[int, int, int]) -> Tuple[Img, Tuple[int, int, int, int, int, int]]:
        """Load, scale and randomly crop an image to its bucket size.

        :param ref: ``(index, bucket_w, bucket_h)``.
        :return: ``(image, (h, w, crop_top, crop_left, bucket_h, bucket_w))``
            where ``h``/``w`` are the original image size and the crop offsets
            are measured on the scaled image.
        """
        with Image.open(self.image_files[ref[0]]) as source:
            img = source.convert('RGB')
        w, h = img.size
        bucket_w = ref[1]
        bucket_h = ref[2]
        if w / h >= bucket_w / bucket_h:
            # cut width
            sheight = bucket_h
            swidth = int(round(w / (h / bucket_h)))
            crop_left = np.random.randint(0, swidth - bucket_w + 1)
            crop_top = 0
        else:
            # cut height
            swidth = bucket_w
            sheight = int(round(h / (w / bucket_w)))
            crop_left = 0
            crop_top = np.random.randint(0, sheight - bucket_h + 1)
        img2 = img.resize((swidth, sheight), resample=Image.Resampling.BICUBIC)
        img3 = visionF.crop(img2, crop_top, crop_left, bucket_h, bucket_w)
        del img
        del img2
        return img3, (h, w, crop_top, crop_left, bucket_h, bucket_w)

    # gets caption by removing the extension from the filename and replacing it with .txt
    def get_caption(self, ref: Tuple[int, int, int]) -> str:
        """Return the caption for ``ref[0]``; '' when it is missing or unreadable.

        Captions (including the empty caption of an image without a ``.txt``
        file) are cached; a read error returns '' without caching.
        """
        index = ref[0]
        if index in self.caption_cache:
            return self.caption_cache[index]
        stem, _ = os.path.splitext(self.image_files[index])
        try:
            with open(stem + '.txt', 'r', encoding='UTF-8') as fp:
                caption = fp.read()
        except FileNotFoundError:
            caption = ''
        except (OSError, UnicodeError):
            return ''
        self.caption_cache[index] = caption
        return caption
# ====================================== #
# Bucketing code stolen from hasuwoof: #
# https://github.com/hasuwoof/huskystack #
# ====================================== #
from dataclasses import dataclass
@dataclass
class TagFreqAdjust :
    """Requested change to how often images carrying ``tag`` appear in training."""
    # Tag whose frequency should be changed.
    tag: str
    # Multiplier applied to the tag's current frequency; used when forced_prob is None.
    adjustment: Optional[float] = None
    # Target frequency for the tag; takes precedence over ``adjustment``.
    forced_prob: Optional[float] = None
    # Whether images may be removed when the target frequency is lower than the current one.
    allow_decrement: Optional[bool] = False
class AspectBucket:
    """Assigns dataset images to aspect-ratio buckets and schedules batches.

    Construction builds the bucket list, applies the optional tag frequency
    adjustments to the list of image indices, and then files every index under
    the bucket closest to the image's aspect ratio.
    """

    def __init__(self, store: ImageStore,
                 num_buckets: int,
                 batch_size: int,
                 bucket_side_min: int = 256,
                 bucket_side_max: int = 1536,
                 bucket_side_increment: int = 64,
                 max_image_area: int = 512 * 768,
                 max_ratio: float = 3,
                 freq_adjust: Optional[List[TagFreqAdjust]] = None):
        """
        :param store: image store providing files and captions.
        :param num_buckets: requested number of buckets.
        :param batch_size: every bucket is trimmed to a multiple of this.
        :param bucket_side_min: smallest bucket side length.
        :param bucket_side_max: largest bucket side length.
        :param bucket_side_increment: step between candidate side lengths.
        :param max_image_area: largest allowed ``w * h`` of a bucket.
        :param max_ratio: largest allowed aspect ratio; ``<= 0`` disables the limit.
        :param freq_adjust: tag frequency adjustments; None means no adjustments.
        """
        self.requested_bucket_count = num_buckets
        self.bucket_length_min = bucket_side_min
        self.bucket_length_max = bucket_side_max
        self.bucket_increment = bucket_side_increment
        self.max_image_area = max_image_area
        self.batch_size = batch_size
        self.total_dropped = 0
        # A fresh list per instance instead of a shared mutable default.
        self.freq_adjust = [] if freq_adjust is None else freq_adjust
        print(self.freq_adjust)

        if max_ratio <= 0:
            self.max_ratio = float('inf')
        else:
            self.max_ratio = max_ratio

        self.store = store
        self.buckets = []
        self._bucket_ratios = []
        self._bucket_interp = None
        self.bucket_data: Dict[tuple, List[int]] = dict()
        self.init_buckets()
        self.total_images = 0
        self.tag_freq_map = defaultdict(int)
        self.tag_image_index_map = defaultdict(list)
        self.freq_adjusted_image_store_indices = []
        self.perform_freq_asjustment()
        self.fill_buckets()

    def perform_freq_asjustment(self) :
        """Build ``freq_adjusted_image_store_indices`` from captions and ``freq_adjust``.

        Every caption is split on commas and each stripped tag is counted. For
        each adjustment, images carrying the tag are duplicated (or, when
        ``allow_decrement`` is set, removed) so the tag's share of the dataset
        approaches the requested frequency. Images untouched by any adjustment
        are added once, and the final list is shuffled.
        """
        entries = self.store.caption_iterator()
        print('performing tag frequency adjustment')
        for caption, idx, _image_filename in tqdm.tqdm(entries) :
            self.total_images += 1
            for tag in caption.split(',') :
                # Strip so ' tag' and 'tag\n' match an adjustment for 'tag'.
                tag = tag.strip()
                self.tag_freq_map[tag] += 1
                self.tag_image_index_map[tag].append(idx)

        adjusted_image_indices = set()
        all_image_indices = set(range(self.total_images))
        for adjustment in self.freq_adjust :
            tag = adjustment.tag
            tag_count = self.tag_freq_map[tag]
            old_freq = tag_count / self.total_images
            if adjustment.forced_prob is not None :
                new_freq = adjustment.forced_prob
            elif adjustment.adjustment is not None :
                new_freq = adjustment.adjustment * old_freq
            else :
                # Neither field set: leave the frequency unchanged.
                new_freq = old_freq
            # Solve (count + d) / (total + d) = new_freq for d.
            # NOTE: new_freq must stay below 1 for this to be meaningful.
            image_diff = int((new_freq * self.total_images - tag_count) / (1 - new_freq))
            image_indices = list(self.tag_image_index_map[tag])
            if not image_indices :
                print(' -- warn, empty', adjustment.tag)
                continue
            adjusted_image_indices.update(image_indices)
            if image_diff > 0 :
                extra_image_indices = np.random.choice(image_indices, image_diff, replace = True)
                image_indices.extend(extra_image_indices)
            elif image_diff < 0 and adjustment.allow_decrement :
                np.random.shuffle(image_indices)
                # image_diff is negative: drop that many entries from the end.
                image_indices = image_indices[: image_diff]
            print('Tag adjustment for', adjustment.tag, 'is from', self.tag_freq_map[adjustment.tag], 'with', image_diff, 'to', len(image_indices))
            self.freq_adjusted_image_store_indices.extend(image_indices)
        self.freq_adjusted_image_store_indices.extend(list(all_image_indices - adjusted_image_indices))
        np.random.shuffle(self.freq_adjusted_image_store_indices)

    def iterate_frequency_adjusted_images(self) -> Generator[Tuple[Img, int], None, None]:
        """Yield ``(opened image, store index)`` for every frequency-adjusted index."""
        for image_store_index in self.freq_adjusted_image_store_indices:
            yield Image.open(self.store.image_files[image_store_index]), image_store_index

    def init_buckets(self):
        """Generate the bucket list and the ratio -> bucket-index interpolator."""
        possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
        possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
                                if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)

        buckets_by_ratio = {}

        # group the buckets by their aspect ratios
        for bucket in possible_buckets:
            w, h = bucket
            # use precision to avoid spooky floats messing up your day
            ratio = '{:.4e}'.format(w / h)
            buckets_by_ratio.setdefault(ratio, set()).add(bucket)

        # now we take the list of buckets we generated and pick the largest by area for each (the last sorted)
        # then we put all of those in a list, sorted by the aspect ratio
        # the square bucket (LxL) will be the first
        unique_ratio_buckets = sorted([sorted(buckets, key=_sort_by_area)[-1]
                                       for buckets in buckets_by_ratio.values()], key=_sort_by_ratio)

        # how many buckets to create for each side of the distribution
        bucket_count_each = int(np.clip((self.requested_bucket_count + 1) / 2, 1, len(unique_ratio_buckets)))

        # we know that the requested_bucket_count must be an odd number, so the indices we calculate
        # will include the square bucket and some linearly spaced buckets along the distribution
        indices = {*np.linspace(0, len(unique_ratio_buckets) - 1, bucket_count_each, dtype=int)}

        # make the buckets, make sure they are unique (to remove the duplicated square bucket), and sort them by ratio
        # here we add the portrait buckets by reversing the dimensions of the landscape buckets we generated above
        buckets = sorted({*(unique_ratio_buckets[i] for i in indices),
                          *(tuple(reversed(unique_ratio_buckets[i])) for i in indices)}, key=_sort_by_ratio)

        self.buckets = buckets

        # cache the bucket ratios and the interpolator that will be used for calculating the best bucket later
        # the interpolator makes a 1d piecewise interpolation where the input (x-axis) is the bucket ratio,
        # and the output is the bucket index in the self.buckets array
        # to find the best fit we can just round that number to get the index
        self._bucket_ratios = [w / h for w, h in buckets]
        self._bucket_interp = interp1d(self._bucket_ratios, list(range(len(buckets))), assume_sorted=True,
                                       fill_value=None)

        for b in buckets:
            self.bucket_data[b] = []

    def get_batch_count(self):
        """Return the total number of full batches across all buckets."""
        return sum(len(b) // self.batch_size for b in self.bucket_data.values())

    def get_bucket_info(self):
        """Return the buckets and their ratios as a JSON string."""
        return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios })

    def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int, int]], None, None]:
        """
        Generator that provides batches where the images in a batch fall on the same bucket

        Each element generated will be:
            (index, w, h)

        where each image is an index into the dataset
        :return:
        """
        max_bucket_len = max(len(b) for b in self.bucket_data.values())
        # One shared random permutation; each bucket walks it and uses only
        # the positions that fall inside its own length.
        index_schedule = list(range(max_bucket_len))
        random.shuffle(index_schedule)

        bucket_len_table = {
            b: len(self.bucket_data[b]) for b in self.buckets
        }

        # One schedule entry per full batch of each bucket, in random order.
        bucket_schedule = []
        for i, b in enumerate(self.buckets):
            bucket_schedule.extend([i] * (bucket_len_table[b] // self.batch_size))

        random.shuffle(bucket_schedule)

        bucket_pos = {
            b: 0 for b in self.buckets
        }

        for bucket_index in bucket_schedule:
            b = self.buckets[bucket_index]
            i = bucket_pos[b]
            bucket_len = bucket_len_table[b]

            batch = []
            while len(batch) != self.batch_size:
                # advance in the schedule until we find an index that is contained in the bucket
                k = index_schedule[i]
                if k < bucket_len:
                    entry = self.bucket_data[b][k]
                    batch.append(entry)

                i += 1

            bucket_pos[b] = i
            yield [(idx, *b) for idx in batch]

    def fill_buckets(self):
        """File every frequency-adjusted image under a bucket and trim to full batches."""
        entries = self.iterate_frequency_adjusted_images()
        total_dropped = 0

        print('performing bucket construction')
        for entry, index in tqdm.tqdm(entries, total=len(self.freq_adjusted_image_store_indices)):
            try:
                kept = self._process_entry(entry, index)
            finally:
                # Only the image size is needed; release the file handle.
                entry.close()
            if not kept:
                total_dropped += 1

        for b, values in self.bucket_data.items():
            # shuffle the entries for extra randomness and to make sure dropped elements are also random
            random.shuffle(values)

            # make sure the buckets have an exact number of elements for the batch
            to_drop = len(values) % self.batch_size
            self.bucket_data[b] = list(values[:len(values) - to_drop])
            total_dropped += to_drop

        self.total_dropped = total_dropped

    def _process_entry(self, entry: Image.Image, index: int) -> bool:
        """Append ``index`` to the best-fitting bucket.

        :return: False when the image is rejected: its aspect ratio exceeds
            ``max_ratio``, or scaling it to the bucket would require cropping
            more than 10% of the bucket's side.
        """
        aspect = entry.width / entry.height

        if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
            return False

        # interp1d raises ValueError outside the range of bucket ratios, so
        # clamp to the nearest bucket; the crop tolerance below still applies.
        clamped = min(max(aspect, self._bucket_ratios[0]), self._bucket_ratios[-1])
        best_bucket = self._bucket_interp(clamped)

        bucket = self.buckets[round(float(best_bucket))]
        (bucket_w, bucket_h) = bucket
        w, h = entry.size
        if w / h >= bucket_w / bucket_h :
            # cut width
            swidth = int(round(w / (h / bucket_h)))
            if (swidth - bucket_w) / bucket_w > 0.1 :
                return False
        else :
            # cut height
            sheight = int(round(h / (w / bucket_w)))
            if (sheight - bucket_h) / bucket_h > 0.1 :
                return False

        self.bucket_data[bucket].append(index)

        return True
class AspectBucketSampler(torch.utils.data.Sampler):
    """Batch sampler that hands each rank its share of the bucket's batches."""

    def __init__(self, bucket: AspectBucket, num_replicas: int = 1, rank: int = 0):
        """
        :param bucket: source of batches (``get_batch_iterator``/``get_batch_count``).
        :param num_replicas: number of participating processes.
        :param rank: index of this process within ``num_replicas``.
        """
        super().__init__(None)
        self.bucket = bucket
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        # subsample the bucket to only include the elements that are assigned to this rank
        batches = list(self.bucket.get_batch_iterator())
        # Ignore the trailing remainder so every rank yields the same number
        # of batches, matching __len__.
        usable = len(batches) - len(batches) % self.num_replicas
        return iter(batches[self.rank:usable:self.num_replicas])

    def __len__(self):
        return self.bucket.get_batch_count() // self.num_replicas
def process_tags(tags: List[str], min_tags=24, max_tags=70):
    """Clean, shuffle and truncate a list of tags.

    Each tag is stripped, underscores are replaced by spaces with 50%
    probability (one draw for the whole list), and everything up to the last
    ':' is removed. The tags are then shuffled and cut to a random length
    drawn from ``[min_tags, max_tags]``.

    :return: ``(tags, False)``.
    """
    cleaned = [tag.strip() for tag in tags]
    if np.random.randint(0, 100) < 50:
        cleaned = [tag.replace('_', ' ') for tag in cleaned]
    # rpartition yields the whole tag when it contains no ':'.
    final_tags = [tag.rpartition(':')[2] for tag in cleaned]
    kept_tags = np.random.randint(min_tags, max_tags + 1)
    np.random.shuffle(final_tags)
    return final_tags[:kept_tags], False
def expand_prefix(tag: str) -> str :
    """Expand a short tag prefix (art/copy/char/gen) to its long form.

    Tags that do not start with one of the short prefixes followed by ':' are
    returned unchanged.
    """
    long_forms = (
        ('art', 'artist'),
        ('copy', 'copyright'),
        ('char', 'character'),
        ('gen', 'general'),
    )
    for short, full in long_forms:
        if tag.startswith(short + ':'):
            # Keep the ':' and everything after it.
            return full + tag[len(short):]
    return tag
def strip_prefix(tag: str) -> str :
    """Return the part of ``tag`` after its last ':' (the whole tag if none)."""
    _, _, remainder = tag.rpartition(':')
    return remainder
def is_artist_or_character(tag):
    """Return True when ``tag`` starts with 'character:' or 'artist:'."""
    return tag.startswith(("character:", "artist:"))
class AspectDataset(torch.utils.data.Dataset):
    """Dataset that turns ``(index, bucket_w, bucket_h)`` references into training examples."""

    def __init__(self, store: ImageStore, device: torch.device, tokenizer1, tokenizer2, ucg: float = 0.1):
        """
        :param store: image store used to load images and captions.
        :param device: stored on the instance; not used by the methods below.
        :param tokenizer1: tokenizer producing ``input_ids1``; also supplies ``model_max_length``.
        :param tokenizer2: tokenizer producing ``input_ids2``.
        :param ucg: probability of replacing an example's prompt with ''.
        """
        self.store = store
        self.device = device
        self.ucg = ucg
        self.tokenizer1 = tokenizer1
        self.tokenizer2 = tokenizer2

    def process_prompt(self, tags: str) -> Tuple[str, float] :
        """Split a comma-separated caption, normalise its tags and re-join them shuffled.

        :return: ``(prompt, weight)``; the weight is always 1.
        """
        tags = tags.split(',')
        tags = [expand_prefix(t) for t in tags]
        tags, _ = process_tags(tags)
        np.random.shuffle(tags)
        return ','.join(tags), 1

    def __len__(self):
        return len(self.store)

    def __getitem__(self, item: Tuple[int, int, int]):
        """Build one example dict for the ``(index, bucket_w, bucket_h)`` reference."""
        return_dict = {'pixel_values': None, 'prompt': None, 'weight': None}

        image_file, (h, w, crop_top, crop_left, bucket_h, bucket_w) = self.store.get_image(item)

        # HWC values in [0, 255] -> CHW float tensor in [-1, 1].
        image_content = np.array(image_file).astype(np.float32) / 255.0
        image_content = torch.from_numpy(image_content.transpose(2, 0, 1))
        image_content = 2.0 * image_content - 1.0

        return_dict['pixel_values'] = image_content
        return_dict['micro_condition'] = (h, w, crop_top, crop_left, bucket_h, bucket_w)
        prompt, weight = self.process_prompt(self.store.get_caption(item))
        return_dict['weight'] = weight
        # Drop the text condition with probability ``ucg``.
        if random.random() <= self.ucg:
            prompt = ''
        return_dict['prompt'] = prompt
        return return_dict

    def _tokenize(self, tokenizer, prompts, max_tokens):
        """Tokenize each non-None prompt without special tokens or padding."""
        return [
            tokenizer(
                [prompt],
                truncation=True,
                return_length=True,
                return_overflowing_tokens=False,
                padding=False,
                add_special_tokens=False,
                max_length=max_tokens,
            ).input_ids[0]
            for prompt in prompts if prompt is not None
        ]

    def collate_fn(self, examples):
        """Combine examples into a batch dict of tensors, prompts and token ids."""
        # Filter once so every field of the batch stays aligned.
        examples = [example for example in examples if example is not None]
        pixel_values = torch.stack([example['pixel_values'] for example in examples])
        # .to()/.float() return new tensors; keep the result.
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        max_length = self.tokenizer1.model_max_length
        max_chunks = args.extended_mode_chunks
        # Each chunk reserves two positions for the BOS/EOS tokens added later.
        max_tokens = (max_length * max_chunks) - (max_chunks * 2)
        micro_condition = [example['micro_condition'] for example in examples]
        prompts = [example['prompt'] for example in examples]
        input_ids1 = self._tokenize(self.tokenizer1, prompts, max_tokens)
        input_ids2 = self._tokenize(self.tokenizer2, prompts, max_tokens)
        weights = torch.tensor([example['weight'] for example in examples], dtype = torch.float32)

        return {
            'pixel_values': pixel_values,
            'prompts': prompts,
            'input_ids1': input_ids1,
            'input_ids2': input_ids2,
            'weights': weights,
            'micro_condition': micro_condition
        }
def encode_prompts_small_clip(device, tokenizer, text_encoder, input_ids) :
    """Encode token id lists with the text encoder and return hidden states.

    Prompts that fit into one tokenizer window are framed with BOS and padded
    with EOS to ``model_max_length``. Longer prompts are padded with EOS to a
    multiple of the window body (``model_max_length - 2``), split into chunks,
    each chunk is framed with BOS/EOS and encoded separately, and the results
    are concatenated along the token axis.

    NOTE: ``input_ids`` is modified in place (its rows are replaced by the
    padded rows).

    :param device: device the token tensors are moved to before encoding.
    :param tokenizer: supplies ``model_max_length``, ``bos_token_id`` and ``eos_token_id``.
    :param text_encoder: encoder called with ``output_hidden_states=True``.
    :param input_ids: list of token id lists, one per prompt.
    :return: the selected hidden-state layer for the whole batch.
    """
    # Use the wrapped module when the encoder is wrapped in DistributedDataParallel.
    if type(text_encoder) is torch.nn.parallel.DistributedDataParallel:
        text_encoder = text_encoder.module

    window = tokenizer.model_max_length
    # NOTE(review): this overwrites the --clip_penultimate CLI value on the
    # global ``args``, so the penultimate layer is always used here - confirm
    # this is intended.
    args.clip_penultimate = True
    layer_idx = -2 if args.clip_penultimate else -1

    with torch.autocast('cuda', enabled=args.fp16):
        body_len = window - 2
        longest = max(len(ids) for ids in input_ids)
        padded_len = np.ceil(longest / body_len).astype(int).item() * body_len

        if padded_len <= body_len:
            # Single window: BOS + tokens + EOS padding up to model_max_length.
            for row, ids in enumerate(input_ids):
                padding = np.full((tokenizer.model_max_length - len(ids) - 1), tokenizer.eos_token_id)
                input_ids[row] = [tokenizer.bos_token_id, *ids, *padding]
            encoded = text_encoder(torch.asarray(input_ids).to(device), output_hidden_states=True)
            outs = encoded.hidden_states[layer_idx]
        else:
            # Pad every row with EOS so the batch is rectangular.
            for row, ids in enumerate(input_ids):
                if len(ids) < padded_len:
                    input_ids[row] = [*ids, *np.full((padded_len - len(ids)), tokenizer.eos_token_id)]
            batch_t = torch.tensor(input_ids)

            chunk_result = []
            for start in range(0, padded_len, body_len):
                body = batch_t[:, start:start + body_len]
                bos = torch.full((body.shape[0], 1), tokenizer.bos_token_id)
                eos = torch.full((body.shape[0], 1), tokenizer.eos_token_id)
                framed = torch.cat((bos, body, eos), 1)
                chunk_result.append(
                    text_encoder(framed.to(device), output_hidden_states=True)['hidden_states'][layer_idx]
                )
            outs = torch.cat(chunk_result, dim=-2)

    outs = torch.stack(tuple(outs))
    return outs
def encode_prompts_big_clip(device, tokenizer, text_encoder, input_ids):
    """Encode pre-tokenized prompts with the second CLIP text encoder.

    Besides the per-token hidden states this also returns a pooled embedding.
    Long prompts are split into windows of ``model_max_length - 2`` tokens;
    each window is encoded separately and the pooled embeddings of all
    windows are combined as a weighted average (see the weights below).

    Args:
        device: device the token tensors are moved to before encoding.
        tokenizer: tokenizer providing ``model_max_length``, ``bos_token_id``
            and ``eos_token_id``. Padding uses token id 0.
        text_encoder: encoder (optionally wrapped in DistributedDataParallel)
            whose output exposes ``text_embeds`` and ``hidden_states``.
        input_ids: list of un-padded token id lists, one per prompt. The list
            is NOT modified; padding is done on a private copy.

    Returns:
        Tuple ``(outs, pooled_output)``: hidden states of the selected layer
        (windows concatenated along dim -2) and the weighted pooled embedding.
    """
    if isinstance(text_encoder, torch.nn.parallel.DistributedDataParallel):
        text_encoder = text_encoder.module

    # NOTE(review): this overwrites the global setting for the whole process,
    # so the penultimate layer is always used. Kept for backward compatibility.
    args.clip_penultimate = True
    layer_idx = -2 if args.clip_penultimate else -1

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    pad_id = 0
    max_standard_tokens = tokenizer.model_max_length - 2

    # Work on copies so the caller's batch is not padded in place.
    token_rows = [list(ids) for ids in input_ids]
    per_sample_len = [len(row) for row in token_rows]

    all_pool_outputs = []
    all_pool_outputs_weights = []

    def as_weight_column(values, like):
        # Column vector of per-sample weights matching the embeddings' dtype.
        return torch.asarray(values).view(-1, 1).to(device).to(like.dtype)

    with torch.autocast('cuda', enabled=args.fp16):
        longest = max(per_sample_len)
        # Round the longest prompt up to a whole number of windows.
        num_windows = (longest + max_standard_tokens - 1) // max_standard_tokens
        max_len = num_windows * max_standard_tokens

        if max_len > max_standard_tokens:
            padded = [row + [pad_id] * (max_len - len(row)) for row in token_rows]
            batch_t = torch.tensor(padded)
            batch_size = batch_t.shape[0]
            # 1-based index of the window that holds each sample's last token.
            # An empty prompt is clamped to window 1 so it still receives an
            # EOS token, as it does in the single-window branch below.
            required_chunks = [
                max(1, ((sample_len - 1) // max_standard_tokens) + 1)
                for sample_len in per_sample_len
            ]
            chunk_result = []
            for chunk_idx, start in enumerate(range(0, max_len, max_standard_tokens)):
                cur_chunk = chunk_idx + 1
                window = batch_t[:, start:start + max_standard_tokens]
                chunk = torch.cat(
                    (
                        torch.full((batch_size, 1), bos_id),
                        window,
                        torch.full((batch_size, 1), pad_id),
                    ),
                    1,
                )
                sample_weight = []
                for j, (sample_len, required_chunk) in enumerate(zip(per_sample_len, required_chunks)):
                    if cur_chunk == required_chunk:
                        # Window containing the end of this prompt: put EOS
                        # right after the last real token (+1 skips BOS).
                        last_valid_pos = sample_len % max_standard_tokens
                        if last_valid_pos == 0 and sample_len != 0:
                            last_valid_pos = max_standard_tokens
                        chunk[j, last_valid_pos + 1] = eos_id
                        # Weight = fraction of the window that is occupied
                        # (real tokens plus BOS and EOS).
                        sample_weight.append(float(last_valid_pos + 2) / float(max_standard_tokens + 2))
                    elif cur_chunk < required_chunk:
                        # Fully occupied window.
                        sample_weight.append(1.0)
                    else:
                        # Window past the end of the prompt.
                        sample_weight.append(2.0 / (max_standard_tokens + 2.0))
                out_states = text_encoder(chunk.to(device), output_hidden_states=True)
                text_embeds = out_states.text_embeds  # pooled
                all_pool_outputs.append(text_embeds)
                all_pool_outputs_weights.append(as_weight_column(sample_weight, text_embeds))
                chunk_result.append(out_states.hidden_states[layer_idx])
            outs = torch.cat(chunk_result, dim=-2)
        else:
            # Single window: BOS + tokens + EOS, then padding to full length.
            padded = [
                [bos_id, *row, eos_id, *([pad_id] * (tokenizer.model_max_length - len(row) - 2))]
                for row in token_rows
            ]
            sample_weight = [1.0] * len(token_rows)
            out_states = text_encoder(torch.asarray(padded).to(device), output_hidden_states=True)
            text_embeds = out_states.text_embeds  # pooled
            all_pool_outputs.append(text_embeds)
            all_pool_outputs_weights.append(as_weight_column(sample_weight, text_embeds))
            outs = out_states.hidden_states[layer_idx]

        outs = torch.stack(tuple(outs))
        # Weighted mean of the per-window pooled embeddings (windows on dim -1).
        all_pool_outputs = torch.stack(all_pool_outputs, dim=-1)
        all_pool_outputs_weights = torch.stack(all_pool_outputs_weights, dim=-1)
        pooled_output = (all_pool_outputs * all_pool_outputs_weights).sum(dim=-1) / all_pool_outputs_weights.sum(dim=-1)
    return outs, pooled_output
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; a detached clone of
                each one becomes the initial shadow (averaged) value.
            decay: upper bound of the EMA decay factor.
        """
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]
        self.decay = decay
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the update factor for the exponential moving average.

        The decay is warmed up as ``(1 + step) / (10 + step)`` and capped at
        ``self.decay``. Note that the returned value is ``1 - decay``, i.e.
        the fraction of the distance to the live parameter that the shadow
        parameter moves in one step.
        """
        value = (1 + optimization_step) / (10 + optimization_step)
        return 1 - min(self.decay, value)

    @torch.no_grad()
    def step(self, parameters):
        """
        Update the shadow parameters from the current model parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`, in the same order
                as the ones passed to the constructor.
        """
        parameters = list(parameters)
        self.optimization_step += 1
        # Keep this in a local: storing it in self.decay would replace the
        # configured cap with (1 - decay) and corrupt every later step.
        one_minus_decay = self.get_decay(self.optimization_step)
        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                # Frozen parameters are mirrored verbatim.
                s_param.copy_(param)
        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    # From CompVis LitEMA implementation
    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters. The stored copies are released
        afterwards, so `store` must be called again before the next `restore`.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
        del self.collected_params
        gc.collect()

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied
                to floating point buffers.
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]
def _get_add_time_ids(unet, text_encoder_2: CLIPTextModelWithProjection, everything, dtype):
add_time_ids = list(everything)
if type(text_encoder_2) is torch.nn.parallel.DistributedDataParallel:
text_encoder_2 = text_encoder_2.module
if type(unet) is torch.nn.parallel.DistributedDataParallel:
unet = unet.module
passed_add_embed_dim = (
unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_2.config.projection_dim
)
expected_add_embed_dim = unet.add_embedding.linear_1.in_features
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
def compute_snr(noise_scheduler, timesteps):
    """
    Return the signal-to-noise ratio ``(alpha / sigma) ** 2`` for each timestep.

    ``alpha`` is ``sqrt(alphas_cumprod)`` and ``sigma`` is
    ``sqrt(1 - alphas_cumprod)``, both looked up at ``timesteps`` from
    ``noise_scheduler.alphas_cumprod``. Follows
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    cumulative_alphas = noise_scheduler.alphas_cumprod

    def lookup(table):
        # Pick the entries for the requested timesteps on their device, then
        # append trailing axes until the rank matches `timesteps`.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        picked = table.to(device=timesteps.device)[timesteps].float()
        for _ in range(len(timesteps.shape) - len(picked.shape)):
            picked = picked[..., None]
        return picked.expand(timesteps.shape)

    alpha = lookup(cumulative_alphas ** 0.5)
    sigma = lookup((1.0 - cumulative_alphas) ** 0.5)

    ratio = alpha / sigma
    return ratio ** 2
def main(enabled_dis = True):
"""
TODO:
better image loader
gradient accumulation
//tag loss manager
//randomize tags
//sd-webui text encoding
prior preserving loss
inpainting objective
"""
rank = get_rank()
world_size = get_world_size()
torch.cuda.set_device(rank)
if rank == 0:
os.makedirs(args.output_path, exist_ok=True)
mode = 'disabled'
if args.enablewandb:
mode = 'online'
if args.hf_token is not None:
os.environ['HF_API_TOKEN'] = args.hf_token
args.hf_token = None
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode=mode)
# Inform the user of host, and various versions -- useful for debugging issues.
print("RUN_NAME:", args.run_name)
print("HOST:", socket.gethostname())
print("CUDA:", torch.version.cuda)