Skip to content

Commit

Permalink
[Modify] Append map_location for load different device model to anoth…
Browse files Browse the repository at this point in the history
…er type device
  • Loading branch information
Y-nuclear committed Oct 31, 2023
1 parent cfd60bf commit 34bfdcf
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/gnnwr/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,18 +426,20 @@ def predict_weight(self, dataset):
result = torch.cat((result, weight), 0)
result = result.cpu().detach().numpy()
return result
def load_model(self, path, use_dict=False):

def load_model(self, path, use_dict=False, map_location=None):
"""
load model from the path
:param path: the path of the model
:param use_dict: whether use dict to load the model
:param map_location: map location
"""
if use_dict:
data = torch.load(path).state_dict()
data = torch.load(path, map_location=map_location).state_dict()
self._model.load_state_dict(data)
else:
self._model = torch.load(path)
self._model = torch.load(path, map_location=map_location)
self.__istrained = True

def getLoss(self):
Expand Down

0 comments on commit 34bfdcf

Please sign in to comment.