diff --git a/src/gnnwr/models.py b/src/gnnwr/models.py index 9610bca..63b5237 100644 --- a/src/gnnwr/models.py +++ b/src/gnnwr/models.py @@ -426,18 +426,20 @@ def predict_weight(self, dataset): result = torch.cat((result, weight), 0) result = result.cpu().detach().numpy() return result - def load_model(self, path, use_dict=False): + + def load_model(self, path, use_dict=False, map_location=None): """ load model from the path :param path: the path of the model :param use_dict: whether use dict to load the model + :param map_location: map location """ if use_dict: - data = torch.load(path).state_dict() + data = torch.load(path, map_location=map_location).state_dict() self._model.load_state_dict(data) else: - self._model = torch.load(path) + self._model = torch.load(path, map_location=map_location) self.__istrained = True def getLoss(self):