Skip to content

Commit

Permalink
Improve cubic interpolation (#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak authored Nov 2, 2023
1 parent 3793cd5 commit 5ad9e5a
Showing 1 changed file with 47 additions and 5 deletions.
52 changes: 47 additions & 5 deletions lib/scholar/interpolation/cubic_spline.ex
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ defmodule Scholar.Interpolation.CubicSpline do
%Scholar.Interpolation.CubicSpline{
coefficients: Nx.tensor(
[
[0.0, 1.500000238418579, -3.500000238418579, 2.0],
[0.0, 1.5, -0.4999999403953552, 0.0]
[0.0, 1.5, -3.5, 2.0],
[0.0, 1.5, -0.5, 0.0]
]
),
x: Nx.tensor(
Expand Down Expand Up @@ -110,7 +110,7 @@ defmodule Scholar.Interpolation.CubicSpline do

b = Nx.stack([2 * slope[0], 3 * (dx[0] * slope[1] + dx[1] * slope[0]), 2 * slope[1]])

Nx.LinAlg.solve(a, b)
tridiagonal_solve(a, b)

{_, :not_a_knot} ->
up_diag =
Expand Down Expand Up @@ -151,7 +151,7 @@ defmodule Scholar.Interpolation.CubicSpline do
Nx.new_axis(b_n, 0)
])

Nx.LinAlg.solve(a, b)
tridiagonal_solve(a, b)

_ ->
up_diag =
Expand Down Expand Up @@ -190,7 +190,7 @@ defmodule Scholar.Interpolation.CubicSpline do
Nx.new_axis(b_n, 0)
])

Nx.LinAlg.solve(a, b)
tridiagonal_solve(a, b)
end

t = (s[0..-2//1] + s[1..-1//1] - 2 * slope) / dx
Expand Down Expand Up @@ -291,4 +291,46 @@ defmodule Scholar.Interpolation.CubicSpline do

Nx.reshape(result, original_shape)
end

defnp tridiagonal_solve(a, b) do
n = Nx.size(b)
w = Nx.broadcast(0, {n - 1})
p = g = Nx.broadcast(0, {n})
i = Nx.take_diagonal(a, offset: -1)
j = Nx.take_diagonal(a)
k = Nx.take_diagonal(a, offset: 1)

w_0 = k[0] / j[0]
g_0 = b[0] / j[0]
w = Nx.indexed_put(w, Nx.new_axis(0, 0), w_0)
g = Nx.indexed_put(g, Nx.new_axis(0, 0), g_0)

{{w, g}, _} =
while {{w, g}, {index = 1, i, j, k, b}}, index < n do
w =
if index < n - 1 do
w_i = k[index] / (j[index] - i[index - 1] * w[index - 1])
Nx.indexed_put(w, Nx.new_axis(index, 0), w_i)
else
w
end

g_i = (b[index] - i[index - 1] * g[index - 1]) / (j[index] - i[index - 1] * w[index - 1])
g = Nx.indexed_put(g, Nx.new_axis(index, 0), g_i)

{{w, g}, {index + 1, i, j, k, b}}
end

p = Nx.indexed_put(p, Nx.new_axis(n - 1, 0), g[n - 1])

{p, _} =
while {p, {index = n - 1, g, w}}, index > 0 do
p_i = g[index - 1] - w[index - 1] * p[index]
p = Nx.indexed_put(p, Nx.new_axis(index - 1, 0), p_i)

{p, {index - 1, g, w}}
end

p
end
end

0 comments on commit 5ad9e5a

Please sign in to comment.