Skip to content

Commit

Permalink
load model with weights only specified for newer versions of pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
Abe404 committed Nov 18, 2024
1 parent e8f377f commit fbb2846
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions trainer/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,34 @@ def get_latest_model_paths(model_dir, k):
fpaths = [os.path.join(model_dir, f) for f in fnames]
return fpaths

def load_model(model_path):

def load_model_weights_only(model_path):
# only works with newer versions of pytorch
model = UNetGNRes()
try:
model.load_state_dict(torch.load(model_path, device))
model.load_state_dict(torch.load(model_path, device, weights_only=True))
model = torch.nn.DataParallel(model)
except:
model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load(model_path, device))
model.load_state_dict(torch.load(model_path, device, weights_only=True))
model.to(device)
return model


def load_model(model_path):
try:
model = load_model_weights_only(model_path)
except:
model = UNetGNRes()
try:
model.load_state_dict(torch.load(model_path, device))
model = torch.nn.DataParallel(model)
except:
model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load(model_path, device))
model.to(device)
return model

def create_first_model_with_random_weights(model_dir):
# used when no model was specified on project creation.
model_num = 1
Expand Down

0 comments on commit fbb2846

Please sign in to comment.