[Modify] Replace __test() by evaluate() and fixed the r2 of train and valid
Y-nuclear committed Jan 9, 2025
1 parent d9fddb7 commit e7b9e92
Showing 1 changed file with 47 additions and 23 deletions.
70 changes: 47 additions & 23 deletions src/gnnwr/models.py
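At a glance, the commit replaces the split-specific `__test()` with a generic `__evaluate(dataset)` that returns the loss plus a `DIAGNOSIS` object, and `result()` now calls it once per split to report train, valid, and test R2. A minimal, self-contained sketch of that pattern (the toy model, loaders, and the standalone `evaluate()` helper below are illustrative, not the gnnwr API):

```python
# Sketch of the "one evaluation helper for every split" pattern.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, criterion, loader):
    """Return (mean loss, predictions, labels) for one dataset split."""
    model.eval()
    total_loss, count = 0.0, 0
    preds, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            out = model(x)
            loss = criterion(out, y)
            total_loss += loss.item() * x.size(0)   # accumulate per-sample loss
            count += x.size(0)
            preds.append(out)
            labels.append(y)
    return total_loss / count, torch.cat(preds), torch.cat(labels)

model, criterion = nn.Linear(4, 1), nn.MSELoss()
loaders = {
    name: DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)
    for name in ("train", "valid", "test")
}
for name, loader in loaders.items():
    loss, pred, label = evaluate(model, criterion, loader)
    # Same R2 form as the removed __test(): 1 - SS_res / SS_tot.
    r2 = 1 - torch.sum((pred - label) ** 2) / torch.sum((label - label.mean()) ** 2)
    print(f"{name}: loss={loss:.4f}, R2={r2.item():.4f}")
```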
@@ -393,15 +393,15 @@ def __valid(self):
else:
self._noUpdateEpoch += 1

def __test(self):
def __evaluate(self, dataset):
"""
test the network
"""
self._model.eval()
test_loss = 0
out_list = torch.tensor([],dtype=torch.float32,device=self._device)
label_list = torch.tensor([],dtype=torch.float32,device=self._device)
data_loader = self._test_dataset.dataloader
data_loader = dataset.dataloader # dataset
x_data = torch.tensor([],dtype=torch.float32,device=self._device)
y_data = torch.tensor([],dtype=torch.float32,device=self._device)
y_pred = torch.tensor([],dtype=torch.float32,device=self._device)
@@ -424,11 +424,10 @@ def __test(self):
test_loss += loss.item() * data[0].size(0)
else:
test_loss += loss.item() * data.size(0) # accumulate the loss
test_loss /= len(self._test_dataset)
self.__testLoss = test_loss
self.__testr2 = 1 - torch.sum((out_list - label_list) ** 2) / torch.sum((label_list - torch.mean(label_list)) ** 2)
self._test_diagnosis = DIAGNOSIS(weight_all, x_data, y_data, y_pred)
return self._test_diagnosis.R2().data

test_loss /= len(dataset)

return test_loss, DIAGNOSIS(weight_all, x_data, y_data, y_pred)

def run(self, max_epoch=1, early_stop=-1,**kwargs):
"""
@@ -602,10 +601,10 @@ def load_model(self, path, use_dict=False, map_location=None):
the location can be ``"cpu"`` or ``"cuda"``
"""
if use_dict:
data = torch.load(path, map_location=map_location)
data = torch.load(path, map_location=map_location, weights_only=False)
self._model.load_state_dict(data)
else:
self._model = torch.load(path, map_location=map_location)
self._model = torch.load(path, map_location=map_location, weights_only=False)
if self._use_gpu:
self._model = self._model.cuda()
self._out = self._out.cuda()
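The recurring `weights_only=False` additions in this commit track a PyTorch behavior change: newer releases default `torch.load(..., weights_only=True)`, which only unpickles tensor/state-dict style data and fails on files that hold a whole pickled `nn.Module`. A small sketch of both load paths (file names are placeholders):

```python
# Saving a full model object and reloading it with weights_only=False.
import torch
from torch import nn

model = nn.Linear(4, 1)
torch.save(model, "model.pkl")                       # pickles the whole module
restored = torch.load("model.pkl", map_location="cpu", weights_only=False)

# State-dict round trip also loads fine with weights_only=False.
torch.save(model.state_dict(), "weights.pkl")
state = torch.load("weights.pkl", map_location="cpu", weights_only=False)
restored.load_state_dict(state)
```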
@@ -631,9 +630,9 @@ def gpumodel_to_cpu(self, path, save_path, use_model=True):
whether use dict to load the model (default: ``True``)
"""
if use_model:
data = torch.load(path, map_location='cpu').state_dict()
data = torch.load(path, map_location='cpu', weights_only=False).state_dict()
else:
data = torch.load(path, map_location='cpu')
data = torch.load(path, map_location='cpu', weights_only=False)
new_state_dict = OrderedDict()
for k, v in data.items():
name = k[7:] # remove module.
@@ -688,10 +687,10 @@ def result(self, path=None, use_dict=False, map_location=None):
if path is None:
path = self._modelSavePath + "/" + self._modelName + ".pkl"
if use_dict:
data = torch.load(path, map_location=map_location)
data = torch.load(path, map_location=map_location, weights_only=False)
self._model.load_state_dict(data)
else:
self._model = torch.load(path, map_location=map_location)
self._model = torch.load(path, map_location=map_location, weights_only=False)
if self._use_gpu:
self._model = nn.DataParallel(module=self._model) # parallel computing
self._model = self._model.cuda()
@@ -700,7 +699,13 @@ def result(self, path=None, use_dict=False, map_location=None):
self._model = self._model.cpu()
self._out = self._out.cpu()
with torch.no_grad():
self.__test()
_ , self._train_diagnosis = self.__evaluate(self._train_dataset)
self._trainr2 = self._train_diagnosis.R2().data
_ , self._valid_diagnosis = self.__evaluate(self._valid_dataset)
self._validr2 = self._valid_diagnosis.R2().data
self.__testLoss, self._test_diagnosis = self.__evaluate(self._test_dataset)
self.__testr2 = self._test_diagnosis.R2().data


logging.info("Test Loss: " + str(self.__testLoss) + "; Test R2: " + str(self.__testr2))
# print result
@@ -719,9 +724,8 @@ def result(self, path=None, use_dict=False, map_location=None):
print("\n--------------------Result Information----------------")
print("Test Loss: | {:>25.5f}".format(self.__testLoss))
print("Test R2 : | {:>25.5f}".format(self.__testr2))
if self._besttrainr2 is not None and self._besttrainr2 != float('-inf'):
print("Train R2 : | {:>25.5f}".format(self._besttrainr2))
print("Valid R2 : | {:>25.5f}".format(self._bestr2))
print("Train R2 : | {:>25.5f}".format(self._trainr2))
print("Valid R2 : | {:>25.5f}".format(self._validr2))
print("RMSE: | {:>30.5f}".format(self._test_diagnosis.RMSE().data))
print("AIC: | {:>30.5f}".format(self._test_diagnosis.AIC()))
print("AICc: | {:>30.5f}".format(self._test_diagnosis.AICc()))
@@ -763,42 +767,49 @@ def reg_result(self, filename=None, model_path=None, use_dict=False, only_return
model_path = self._modelSavePath + "/" + self._modelName + ".pkl"

if use_dict:
data = torch.load(model_path, map_location=map_location)
data = torch.load(model_path, map_location=map_location, weights_only=False)
self._model.load_state_dict(data)
else:
self._model = torch.load(model_path, map_location=map_location)
self._model = torch.load(model_path, map_location=map_location, weights_only=False)

if self._use_gpu:
self._model = nn.DataParallel(module=self._model)
self._model = self._model.cuda()
self._out = self._out.cuda()
self._model,self._out = self._model.cuda(),self._out.cuda()
else:
self._model = self._model.cpu()
self._out = self._out.cpu()
self._model, self._out = self._model.cpu(), self._out.cpu()

device = torch.device('cuda') if self._use_gpu else torch.device('cpu')
result = torch.tensor([]).to(torch.float32).to(device)
train_data_size = valid_data_size = 0

with torch.no_grad():
# calculate the result of train dataset
for data, coef, label, data_index in self._train_dataset.dataloader:
data, coef, label, data_index = data.to(device), coef.to(device), label.to(device), data_index.to(
device)
output = self._out(self._model(data).mul(coef.to(torch.float32)))
coefficient = self._model(data).mul(torch.tensor(self._coefficient).to(torch.float32).to(device))
output = torch.cat((coefficient, output, data_index), dim=1)
result = torch.cat((result, output), 0)
train_data_size = len(result)
# calculate the result of valid dataset
for data, coef, label, data_index in self._valid_dataset.dataloader:
data, coef, label, data_index = data.to(device), coef.to(device), label.to(device), data_index.to(
device)
output = self._out(self._model(data).mul(coef.to(torch.float32)))
coefficient = self._model(data).mul(torch.tensor(self._coefficient).to(torch.float32).to(device))
output = torch.cat((coefficient, output, data_index), dim=1)
result = torch.cat((result, output), 0)
valid_data_size = len(result) - train_data_size
# calculate the result of test dataset
for data, coef, label, data_index in self._test_dataset.dataloader:
data, coef, label, data_index = data.to(device), coef.to(device), label.to(device), data_index.to(
device)
output = self._out(self._model(data).mul(coef.to(torch.float32)))
coefficient = self._model(data).mul(torch.tensor(self._coefficient).to(torch.float32).to(device))
output = torch.cat((coefficient, output, data_index), dim=1)
result = torch.cat((result, output), 0)

result = result.cpu().detach().numpy()
columns = list(self._train_dataset.x)
for i in range(len(columns)):
@@ -808,8 +819,21 @@ def reg_result(self, filename=None, model_path=None, use_dict=False, only_return
result = pd.DataFrame(result, columns=columns)
result[self._train_dataset.id] = result[self._train_dataset.id].astype(np.int32)
result["Pred_" + self._train_dataset.y[0]] = result["Pred_" + self._train_dataset.y[0]].astype(np.float32)

# set dataset belong to postprocess
result["dataset_belong"] = 'test'
result.loc[:train_data_size,"dataset_belong"] = 'train'
result.loc[train_data_size:valid_data_size,"dataset_belong"] = 'valid'

# denormalize pred result
if self._train_dataset.y_scale_info:
_, result['denormalized_pred_result'] = self._train_dataset.rescale(None,result)
else:
result['denormalized_pred_result'] = result["Pred_" + self._train_dataset.y[0]]

if only_return:
return result

if filename is not None:
result.to_csv(filename, index=False)
else:
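The tail of `reg_result()` shown above tags each output row with the split it came from and de-normalizes the predicted target when scale information exists. A standalone sketch of that post-processing, assuming contiguous train/valid/test row blocks and min-max scaling of the target (column names, split sizes, and scale values are placeholders):

```python
# Tag each prediction row with its split and undo target normalization.
import numpy as np
import pandas as pd

pred = pd.DataFrame({"Pred_y": np.random.rand(10).astype(np.float32)})
train_size, valid_size = 6, 2            # placeholder split sizes

pred["dataset_belong"] = "test"
col = pred.columns.get_loc("dataset_belong")
pred.iloc[:train_size, col] = "train"
pred.iloc[train_size:train_size + valid_size, col] = "valid"

# Undo min-max scaling of the target, assuming (y_min, y_max) were stored at fit time.
y_min, y_max = 0.0, 100.0                # placeholder scale info
pred["denormalized_pred_result"] = pred["Pred_y"] * (y_max - y_min) + y_min
print(pred)
```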
