From 5ad9e5a5319c43db0e1ae036eb4e3dd51c2514ea Mon Sep 17 00:00:00 2001 From: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com> Date: Thu, 2 Nov 2023 19:33:01 +0100 Subject: [PATCH] Improve cubic interpolation (#204) --- lib/scholar/interpolation/cubic_spline.ex | 52 ++++++++++++++++++++--- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/lib/scholar/interpolation/cubic_spline.ex b/lib/scholar/interpolation/cubic_spline.ex index d1370b46..00901f00 100644 --- a/lib/scholar/interpolation/cubic_spline.ex +++ b/lib/scholar/interpolation/cubic_spline.ex @@ -52,8 +52,8 @@ defmodule Scholar.Interpolation.CubicSpline do %Scholar.Interpolation.CubicSpline{ coefficients: Nx.tensor( [ - [0.0, 1.500000238418579, -3.500000238418579, 2.0], - [0.0, 1.5, -0.4999999403953552, 0.0] + [0.0, 1.5, -3.5, 2.0], + [0.0, 1.5, -0.5, 0.0] ] ), x: Nx.tensor( @@ -110,7 +110,7 @@ defmodule Scholar.Interpolation.CubicSpline do b = Nx.stack([2 * slope[0], 3 * (dx[0] * slope[1] + dx[1] * slope[0]), 2 * slope[1]]) - Nx.LinAlg.solve(a, b) + tridiagonal_solve(a, b) {_, :not_a_knot} -> up_diag = @@ -151,7 +151,7 @@ defmodule Scholar.Interpolation.CubicSpline do Nx.new_axis(b_n, 0) ]) - Nx.LinAlg.solve(a, b) + tridiagonal_solve(a, b) _ -> up_diag = @@ -190,7 +190,7 @@ defmodule Scholar.Interpolation.CubicSpline do Nx.new_axis(b_n, 0) ]) - Nx.LinAlg.solve(a, b) + tridiagonal_solve(a, b) end t = (s[0..-2//1] + s[1..-1//1] - 2 * slope) / dx @@ -291,4 +291,46 @@ defmodule Scholar.Interpolation.CubicSpline do Nx.reshape(result, original_shape) end + + defnp tridiagonal_solve(a, b) do + n = Nx.size(b) + w = Nx.broadcast(0, {n - 1}) + p = g = Nx.broadcast(0, {n}) + i = Nx.take_diagonal(a, offset: -1) + j = Nx.take_diagonal(a) + k = Nx.take_diagonal(a, offset: 1) + + w_0 = k[0] / j[0] + g_0 = b[0] / j[0] + w = Nx.indexed_put(w, Nx.new_axis(0, 0), w_0) + g = Nx.indexed_put(g, Nx.new_axis(0, 0), g_0) + + {{w, g}, _} = + while {{w, g}, {index = 1, i, j, k, b}}, index < n do + w = + if index < n - 1 do + w_i = k[index] / (j[index] - i[index - 1] * w[index - 1]) + Nx.indexed_put(w, Nx.new_axis(index, 0), w_i) + else + w + end + + g_i = (b[index] - i[index - 1] * g[index - 1]) / (j[index] - i[index - 1] * w[index - 1]) + g = Nx.indexed_put(g, Nx.new_axis(index, 0), g_i) + + {{w, g}, {index + 1, i, j, k, b}} + end + + p = Nx.indexed_put(p, Nx.new_axis(n - 1, 0), g[n - 1]) + + {p, _} = + while {p, {index = n - 1, g, w}}, index > 0 do + p_i = g[index - 1] - w[index - 1] * p[index] + p = Nx.indexed_put(p, Nx.new_axis(index - 1, 0), p_i) + + {p, {index - 1, g, w}} + end + + p + end end