diff --git a/RealESRGAN/model.py b/RealESRGAN/model.py index 3a04daf..0dd434e 100644 --- a/RealESRGAN/model.py +++ b/RealESRGAN/model.py @@ -46,7 +46,7 @@ def load_weights(self, model_path, download=True): cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename) print('Weights downloaded to:', os.path.join(cache_dir, local_filename)) - loadnet = torch.load(model_path) + loadnet = torch.load(model_path, weights_only=True) if 'params' in loadnet: self.model.load_state_dict(loadnet['params'], strict=True) elif 'params_ema' in loadnet: @@ -56,7 +56,7 @@ def load_weights(self, model_path, download=True): self.model.eval() self.model.to(self.device) - @torch.cuda.amp.autocast() + @torch.amp.autocast(device_type='cuda') def predict(self, lr_image, batch_size=4, patches_size=192, padding=24, pad_size=15): scale = self.scale