diff --git a/trainer/model_utils.py b/trainer/model_utils.py index 555b1df..43f405a 100644 --- a/trainer/model_utils.py +++ b/trainer/model_utils.py @@ -44,17 +44,34 @@ def get_latest_model_paths(model_dir, k): fpaths = [os.path.join(model_dir, f) for f in fnames] return fpaths -def load_model(model_path): + +def load_model_weights_only(model_path): + # only works with newer versions of pytorch model = UNetGNRes() try: - model.load_state_dict(torch.load(model_path, device)) + model.load_state_dict(torch.load(model_path, device, weights_only=True)) model = torch.nn.DataParallel(model) except: model = torch.nn.DataParallel(model) - model.load_state_dict(torch.load(model_path, device)) + model.load_state_dict(torch.load(model_path, device, weights_only=True)) model.to(device) return model + +def load_model(model_path): + try: + model = load_model_weights_only(model_path) + except: + model = UNetGNRes() + try: + model.load_state_dict(torch.load(model_path, device)) + model = torch.nn.DataParallel(model) + except: + model = torch.nn.DataParallel(model) + model.load_state_dict(torch.load(model_path, device)) + model.to(device) + return model + def create_first_model_with_random_weights(model_dir): # used when no model was specified on project creation. model_num = 1