diff --git a/kan/spline.py b/kan/spline.py index 48e22508..ba23115c 100644 --- a/kan/spline.py +++ b/kan/spline.py @@ -135,6 +135,6 @@ def curve2coef(x_eval, y_eval, grid, k, device="cpu"): # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1) # coef = torch.linalg.lstsq(mat, y_eval.unsqueeze(dim=2)).solution[:, :, 0] - coef = torch.linalg.lstsq(mat.to(device), y_eval.unsqueeze(dim=2).to(device), - driver='gelsy' if device == 'cpu' else 'gels').solution[:, :, 0] + coef = torch.linalg.lstsq(mat.to('cpu'), y_eval.unsqueeze(dim=2).to('cpu'), + driver='gelsy').solution[:, :, 0] return coef.to(device)