From 34bfdcf1ea8fc1858a3de5eb31e3d1e873f8e9db Mon Sep 17 00:00:00 2001 From: Yzy <2154597198@qq.com> Date: Tue, 31 Oct 2023 16:24:48 +0800 Subject: [PATCH] [Modify] Append map_location for load different device model to another type device --- src/gnnwr/models.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gnnwr/models.py b/src/gnnwr/models.py index 9610bca..63b5237 100644 --- a/src/gnnwr/models.py +++ b/src/gnnwr/models.py @@ -426,18 +426,20 @@ def predict_weight(self, dataset): result = torch.cat((result, weight), 0) result = result.cpu().detach().numpy() return result - def load_model(self, path, use_dict=False): + + def load_model(self, path, use_dict=False, map_location=None): """ load model from the path :param path: the path of the model :param use_dict: whether use dict to load the model + :param map_location: map location """ if use_dict: - data = torch.load(path).state_dict() + data = torch.load(path, map_location=map_location).state_dict() self._model.load_state_dict(data) else: - self._model = torch.load(path) + self._model = torch.load(path, map_location=map_location) self.__istrained = True def getLoss(self):