
Commit

#45
maikherbig authored Jan 4, 2023
1 parent e13f34f commit 6748fe4
Showing 1 changed file with 57 additions and 19 deletions.
76 changes: 57 additions & 19 deletions AIDeveloper/aid_backbone.py
@@ -4,7 +4,7 @@
---------
@author: maikherbig
"""
VERSION = "0.4.5" #Python 3.9.9 Version
VERSION = "0.4.7" #Python 3.9.9 Version

import os,sys,gc
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@@ -5317,8 +5317,10 @@ def update_para_dict():
if len(xtra_in)>1:# False and True is present. Not supported
print("Xtra data is used only for some files. Xtra data needs to be used either by all or by none!")
return
xtra_in = list(xtra_in)[0]#this is either True or False

try:
xtra_in = list(xtra_in)[0] #this is either True or False
except:
pass
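
The try/except covers the case where no selected file contributes an xtra_in flag, so the set is empty and list(xtra_in)[0] raises an IndexError. A standalone sketch of the same all-or-nothing check, with a hypothetical per-file flag list and an explicit fallback instead of the bare except:

per_file_flags = []                       # hypothetical stand-in: one xtra_in flag per selected file
xtra_in = set(per_file_flags)             # empty set when no files are selected
if len(xtra_in) > 1:                      # mixed True/False across files is not supported
    raise ValueError("Xtra data needs to be used either by all or by none!")
try:
    xtra_in = list(xtra_in)[0]            # IndexError when the set is empty ...
except IndexError:
    xtra_in = False                       # ... so fall back to "no extra input"
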
############Cropping#####################
X_valid,y_valid,Indices,xtra_valid = [],[],[],[]
for i in range(len(SelectedFiles_valid)):
@@ -5371,11 +5373,16 @@ def update_para_dict():
# msg.setWindowTitle("Export is turned off!")
# msg.setStandardButtons(QtWidgets.QMessageBox.Ok)
# msg.exec_()

X_valid = np.concatenate(X_valid)
y_valid = np.concatenate(y_valid)
if len(X_valid)>0:
X_valid = np.concatenate(X_valid)
y_valid = np.concatenate(y_valid)
xtra_valid = np.concatenate(xtra_valid)
else:
X_valid = np.array(X_valid)
y_valid = np.array(y_valid)
xtra_valid = np.array(xtra_valid)
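
The new if/else is needed because np.concatenate raises on an empty list of arrays; falling back to empty np.array placeholders keeps the later shape checks and the validation-split merge working when no validation files were selected. A quick illustration of the failure mode being avoided:

import numpy as np

parts = []                                # no validation files were selected
try:
    X_valid = np.concatenate(parts)       # ValueError: need at least one array to concatenate
except ValueError:
    X_valid = np.array(parts)             # empty placeholder with shape (0,), as in the new else-branch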

Y_valid = to_categorical(y_valid, nr_classes)# * 2 - 1
xtra_valid = np.concatenate(xtra_valid)
if not bool(self.actionExport_Off.isChecked())==True:
#Save the labels
np.savetxt(new_modelname.split(".model")[0]+'_Valid_Labels.txt',y_valid.astype(int),fmt='%i')
@@ -5385,6 +5392,7 @@ def update_para_dict():
elif len(X_valid.shape)==3:
channels=1
else:
channels = np.nan
print("Invalid data dimension:" +str(X_valid.shape))
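
Setting channels = np.nan in the else-branch keeps the variable defined for unexpected shapes, so the following channels==1 test fails cleanly instead of raising a NameError. The dimension logic written as a small helper (the 4-D branch is hidden above the fold, so its handling here is an assumption):

import numpy as np

def infer_channels(shape):
    # assumed mapping: a 4-D (NHWC) array presumably carries its channels on
    # the last axis; 3-D data is treated as grayscale; anything else is
    # flagged with NaN so later equality checks are simply False
    if len(shape) == 4:
        return shape[-1]
    if len(shape) == 3:
        return 1
    return np.nan
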
if channels==1:
#Add the "channels" dimension
@@ -5400,16 +5408,22 @@ def update_para_dict():
#Validation data can be cropped to final size already since no augmentation
#will happen on this data set
dim_val = X_valid.shape
print("Current dim. of validation set (pixels x pixels) = "+str(dim_val[2]))
if dim_val[2]!=crop:
print("Change dim. (pixels x pixels) of validation set to = "+str(crop))
remove = int(dim_val[2]/2.0 - crop/2.0)
X_valid = X_valid[:,remove:remove+crop,remove:remove+crop,:] #crop to crop x crop pixels #TensorFlow
print("Current dim. of validation set = "+str(dim_val))
if len(dim_val)>1:
if dim_val[2]!=crop:
print("Change dim. (pixels x pixels) of validation set to = "+str(crop))
remove = int(dim_val[2]/2.0 - crop/2.0)
X_valid = X_valid[:,remove:remove+crop,remove:remove+crop,:] #crop to crop x crop pixels #TensorFlow

if xtra_in==True:
print("Add Xtra Data to X_valid")
X_valid = [X_valid,xtra_valid]

# copy the validation set
X_valid_orig = np.copy(X_valid)
y_valid_orig = np.copy(y_valid)
Y_valid_orig = np.copy(Y_valid)
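
Printing the full shape and adding the len(dim_val)>1 guard handles the empty validation set, whose shape is just (0,) and would make dim_val[2] raise an IndexError; the *_orig copies are kept so a manually selected validation set can later be merged with the automatic split. A compact sketch of the guard (the crop value is assumed):

import numpy as np

crop = 32                                 # assumed final image size in pixels
X_valid = np.array([])                    # empty placeholder from the branch above
dim_val = X_valid.shape                   # (0,) -- a single entry
if len(dim_val) > 1 and dim_val[2] != crop:
    remove = int(dim_val[2] / 2.0 - crop / 2.0)
    X_valid = X_valid[:, remove:remove + crop, remove:remove + crop, :]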


####################Update the PopupFitting########################
self.fittingpopups_ui[listindex].lineEdit_modelname_pop.setText(new_modelname) #set the progress bar to zero
@@ -5430,6 +5444,10 @@ def update_para_dict():
self.fittingpopups_ui[listindex].comboBox_paddingMode_pop.setCurrentIndex(index)
#zoom_order
self.fittingpopups_ui[listindex].comboBox_zoomOrder.setCurrentIndex(zoom_order)
# Validation split
self.fittingpopups_ui[listindex].checkBox_validationSplit.setChecked(self.checkBox_validationSplit.isChecked())
self.fittingpopups_ui[listindex].doubleSpinBox_validationSplit.setValue(self.doubleSpinBox_validationSplit.value())
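
Mirroring the two widgets into the popup makes the split settings travel with each fitting session; the training loop further down reads them back from the popup, not from the main window. A minimal PyQt5 sketch of this mirror-then-read pattern (stand-in widgets, not the repository's UI objects):

from PyQt5 import QtWidgets

app = QtWidgets.QApplication([])
main_checkbox, main_spinbox = QtWidgets.QCheckBox(), QtWidgets.QDoubleSpinBox()
popup_checkbox, popup_spinbox = QtWidgets.QCheckBox(), QtWidgets.QDoubleSpinBox()
# mirror the main-window state into the popup, as in the two added lines
popup_checkbox.setChecked(main_checkbox.isChecked())
popup_spinbox.setValue(main_spinbox.value())
# later, the fitting code consults only the popup widgets
use_split = popup_checkbox.isChecked()
perc_valid = popup_spinbox.value()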

#CPU setting
self.fittingpopups_ui[listindex].comboBox_cpu_pop.addItem("Default CPU")
if gpu_used==False:
@@ -5549,11 +5567,9 @@ def update_para_dict():
self.fittingpopups_ui[listindex].checkBox_lossW.setChecked(lossW_expert_on)
self.fittingpopups_ui[listindex].pushButton_lossW.setEnabled(lossW_expert_on)
self.fittingpopups_ui[listindex].lineEdit_lossW.setText(str(lossW_expert))

if channels==1:
channel_text = "Grayscale"
elif channels==3:
channel_text = "RGB"


channel_text = self.get_color_mode()
self.fittingpopups_ui[listindex].comboBox_colorMode_pop.addItems([channel_text])
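
The removed if/elif is replaced by self.get_color_mode(), whose body is not part of this diff; based only on the deleted lines, the old mapping was equivalent to a helper along these lines (hypothetical name and signature):

def color_mode_from_channels(channels):
    # grayscale for single-channel images, RGB for three channels;
    # other values are not expected at this point in the pipeline
    return {1: "Grayscale", 3: "RGB"}.get(channels)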

###############Continue with training data:augmentation############
@@ -5638,7 +5654,29 @@ def update_para_dict():
if channels==1:
#Add the "channels" dimension
X_train = np.expand_dims(X_train,3)


if self.fittingpopups_ui[listindex].checkBox_validationSplit.isChecked():
# Train-Test split (note that training set can be a random batch every epoch. Same is true for the validation set)
perc_valid = self.fittingpopups_ui[listindex].doubleSpinBox_validationSplit.value()
ind_rand = rand_state.choice(a=range(X_train.shape[0]),size=int(np.ceil((perc_valid/100)*X_train.shape[0])),replace=False)
X_temp = X_train[ind_rand]
dim_val = X_temp.shape
if dim_val[2]!=crop:
remove = int(dim_val[2]/2.0 - crop/2.0)
X_temp = X_temp[:,remove:remove+crop,remove:remove+crop,:] #crop to crop x crop pixels #TensorFlow
if X_valid_orig.shape[0]>0: # there is another validation set, manually selected by user
X_valid = np.r_[X_valid_orig,X_temp]
y_valid = np.r_[y_valid_orig,y_train[ind_rand]]
else:
X_valid = X_temp
y_valid = y_train[ind_rand]

Y_valid = to_categorical(y_valid, nr_classes)

X_train = np.delete(X_train,ind_rand,axis=0)
y_train = np.delete(y_train,ind_rand,axis=0)
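
This block is the core of the new feature: a random fraction of the training images is held out, center-cropped to the final size, appended to any manually selected validation set, one-hot encoded, and removed from the training arrays. The same steps as a self-contained function (to_categorical is replaced by an np.eye stand-in so the sketch runs without Keras; rand_state is assumed to be a np.random.RandomState and y_train to hold integer class labels):

import numpy as np

def split_validation(X_train, y_train, perc_valid, crop, rand_state, nr_classes):
    # hold out perc_valid percent of the NHWC training set as validation data
    n = X_train.shape[0]
    ind = rand_state.choice(a=range(n), size=int(np.ceil(perc_valid / 100 * n)), replace=False)
    X_val = X_train[ind]
    if X_val.shape[2] != crop:                          # center-crop to the final pixel size
        remove = int(X_val.shape[2] / 2.0 - crop / 2.0)
        X_val = X_val[:, remove:remove + crop, remove:remove + crop, :]
    y_val = y_train[ind]
    Y_val = np.eye(nr_classes)[y_val.astype(int)]       # one-hot labels, like keras' to_categorical
    X_train = np.delete(X_train, ind, axis=0)           # held-out samples leave the training set
    y_train = np.delete(y_train, ind, axis=0)
    return X_train, y_train, X_val, y_val, Y_val

Because the indices are redrawn each time this code path runs, the held-out batch changes between runs, which matches the comment in the diff that the validation set can be a random batch every epoch.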


t3 = time.perf_counter()
#Some parallellization: use nr_threads (number of CPUs)
nr_threads = 1 #Somehow for MNIST and CIFAR, processing always took longer for nr_threads>1 . I tried nr_threads=2,4,8,16,24
@@ -6118,7 +6156,7 @@ def imgaug_worker(aug_paras,progress_callback,history_callback):
###################################################
###############Actual fitting######################
###################################################

if collection==False:
if model_keras_p == None:
history = model_keras.fit(X_batch, Y_batch,
