Skip to content

Commit

Permalink
Merge pull request #21 from sisl/static2
Browse files Browse the repository at this point in the history
switched to reinterpreting a matrix for performance
  • Loading branch information
zsunberg authored Mar 19, 2018
2 parents 0e6c588 + 0a0dfb5 commit 2eb4e4e
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 43 deletions.
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
julia 0.6
StaticArrays 0.5.1
88 changes: 49 additions & 39 deletions src/GridInterpolations.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
__precompile__()
module GridInterpolations

using StaticArrays

export AbstractGrid, RectangleGrid, SimplexGrid, dimensions, length, label, ind2x, ind2x!, interpolate, maskedInterpolate, interpolants, vertices

abstract type AbstractGrid end
abstract type AbstractGrid{D} end # D is the dimension

mutable struct RectangleGrid <: AbstractGrid
mutable struct RectangleGrid{D} <: AbstractGrid{D}
cutPoints::Vector{Vector{Float64}}
cut_counts::Vector{Int}
cuts::Vector{Float64}
Expand All @@ -15,20 +16,21 @@ mutable struct RectangleGrid <: AbstractGrid
index2::Vector{Int}
weight2::Vector{Float64}

function RectangleGrid(cutPoints...)
function RectangleGrid{D}(cutPoints...) where D
cut_counts = Int[length(cutPoints[i]) for i = 1:length(cutPoints)]
cuts = vcat(cutPoints...)
myCutPoints = Array{Vector{Float64}}(length(cutPoints))
for i = 1:length(cutPoints)
if length(Set(cutPoints[i])) != length(cutPoints[i])
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)",i))
end
if !issorted(cutPoints[i])
error("Cut points must be sorted")
end
numDims = length(cutPoints)
@assert numDims == D
for i = 1:numDims
if length(Set(cutPoints[i])) != length(cutPoints[i])
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)",i))
end
if !issorted(cutPoints[i])
error("Cut points must be sorted")
end
myCutPoints[i] = cutPoints[i]
end
numDims = length(cutPoints)
index = zeros(Int, 2^numDims)
weight = zeros(Float64, 2^numDims)
index[1] = 1
Expand All @@ -37,11 +39,13 @@ mutable struct RectangleGrid <: AbstractGrid
weight2 = zeros(Float64, 2^numDims)
index2[1] = 1
weight2[1] = 1.0
new(myCutPoints, cut_counts, cuts, index, weight, index2, weight2)
return new(myCutPoints, cut_counts, cuts, index, weight, index2, weight2)
end
end

mutable struct SimplexGrid <: AbstractGrid
RectangleGrid(cutPoints...) = RectangleGrid{length(cutPoints)}(cutPoints...)

mutable struct SimplexGrid{D} <: AbstractGrid{D}
cutPoints::Vector{Vector{Float64}}
cut_counts::Vector{Int}
cuts::Vector{Float64}
Expand All @@ -52,35 +56,38 @@ mutable struct SimplexGrid <: AbstractGrid
ilo::Vector{Int} # indices of cuts below point
n_ind::Vector{Int}

function SimplexGrid(cutPoints...)
function SimplexGrid{D}(cutPoints...) where D
cut_counts = Int[length(cutPoints[i]) for i = 1:length(cutPoints)]
cuts = vcat(cutPoints...)
myCutPoints = Array{Vector{Float64}}(length(cutPoints))
for i = 1:length(cutPoints)
if length(Set(cutPoints[i])) != length(cutPoints[i])
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)",i))
end
if !issorted(cutPoints[i])
error("Cut points must be sorted")
end
numDims = length(cutPoints)
@assert numDims == D
for i = 1:numDims
if length(Set(cutPoints[i])) != length(cutPoints[i])
error(@sprintf("Duplicates cutpoints are not allowed (duplicates observed in dimension %d)",i))
end
if !issorted(cutPoints[i])
error("Cut points must be sorted")
end
myCutPoints[i] = cutPoints[i]
end
numDims = length(cutPoints)
index = zeros(Int, numDims+1) # d+1 points for simplex
weight = zeros(Float64, numDims+1)
x_p = zeros(numDims) # residuals
ihi = zeros(Int, numDims) # indicies of cuts above point
ilo = zeros(Int, numDims) # indicies of cuts below point
n_ind = zeros(Int, numDims)
new(myCutPoints, cut_counts, cuts, index, weight, x_p, ihi, ilo, n_ind)
return new(myCutPoints, cut_counts, cuts, index, weight, x_p, ihi, ilo, n_ind)
end
end

SimplexGrid(cutPoints...) = SimplexGrid{length(cutPoints)}(cutPoints...)

Base.length(grid::RectangleGrid) = prod(grid.cut_counts)
Base.length(grid::SimplexGrid) = prod(grid.cut_counts)

dimensions(grid::RectangleGrid) = length(grid.cut_counts)
dimensions(grid::SimplexGrid) = length(grid.cut_counts)
dimensions(grid::AbstractGrid{D}) where D = D
Base.ndims(grid::AbstractGrid{D}) where D = D

label(grid::RectangleGrid) = "multilinear interpolation grid"
label(grid::SimplexGrid) = "simplex interpolation grid"
Expand All @@ -105,7 +112,7 @@ function ind2x(grid::AbstractGrid, ind::Int)
x::Array{Float64}
end

function ind2x!(grid::AbstractGrid, ind::Int, x::Array)
function ind2x!(grid::AbstractGrid, ind::Int, x::AbstractArray)
# Populates x with the value at ind.
# In-place version of ind2x.
# Example:
Expand All @@ -131,7 +138,7 @@ end


# masked interpolation ignores points that are masked
function maskedInterpolate(grid::AbstractGrid, data::DenseArray, x::Vector, mask::BitArray{1})
function maskedInterpolate(grid::AbstractGrid, data::DenseArray, x::AbstractVector, mask::BitArray{1})
index, weight = interpolants(grid, x)
val = 0
totalWeight = 0
Expand All @@ -145,14 +152,14 @@ function maskedInterpolate(grid::AbstractGrid, data::DenseArray, x::Vector, mask
return val / totalWeight
end

interpolate(grid::AbstractGrid, data::Matrix, x::Vector) = interpolate(grid, map(Float64, data[:]), x)
interpolate(grid::AbstractGrid, data::Matrix, x::AbstractVector) = interpolate(grid, map(Float64, data[:]), x)

function interpolate(grid::AbstractGrid, data::DenseArray, x::Vector)
function interpolate(grid::AbstractGrid, data::DenseArray, x::AbstractVector)
index, weight = interpolants(grid, x)
dot(data[index], weight)
end

function interpolants(grid::RectangleGrid, x::Vector)
function interpolants(grid::RectangleGrid, x::AbstractVector)
cut_counts = grid.cut_counts
cuts = grid.cuts

Expand Down Expand Up @@ -220,7 +227,7 @@ function interpolants(grid::RectangleGrid, x::Vector)
grid.index::Vector{Int}, grid.weight::Vector{Float64}
end

function interpolants(grid::SimplexGrid, x::Vector)
function interpolants(grid::SimplexGrid, x::AbstractVector)

weight = grid.weight
index = grid.index
Expand Down Expand Up @@ -318,14 +325,11 @@ function interpolants(grid::SimplexGrid, x::Vector)
return index::Vector{Int}, weight::Vector{Float64}
end

# Returns a matrix of size (num_vertices x grid_dimension)
# where the ith row represents the vertex corresponding to the ith index of grid data
"Return a vector of SVectors where the ith vector represents the vertex corresponding to the ith index of grid data."
function vertices(grid::AbstractGrid)
n_dims = dimensions(grid)
mem = Array{Float64,2}(n_dims, length(grid))

vertex_list::Array{Float64,2} = Array{Float64,2}(length(grid),dimensions(grid))
n_dims::Int = dimensions(grid)

# Iterate over the number of vertices in a grid
for idx = 1 : length(grid)
this_idx::Int = idx-1

Expand All @@ -334,11 +338,17 @@ function vertices(grid::AbstractGrid)
for j = 1 : n_dims
cut_idx::Int = this_idx % grid.cut_counts[j]
this_idx = div(this_idx,grid.cut_counts[j])
vertex_list[idx,j] = grid.cutPoints[j][cut_idx+1]
mem[j, idx] = grid.cutPoints[j][cut_idx+1]
end
end

return vertex_list
#=
This relies on the memory layout of Matrix to stay the same, so is a
possible source of future errors. However, it is documented
(http://juliaarrays.github.io/StaticArrays.jl/stable/pages/
api.html#Arrays-of-static-arrays-1), and tests should catch these errors.
=#
return reinterpret(SVector{n_dims, Float64}, mem, (length(grid),))
end


Expand Down
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ function simplexMagic(NDISC::Int=20, NPOINTS::Int=3, checkFileName::AbstractStri

sInterpValTest = readdlm(checkFileName)

testErr = sum(abs(sInterpVal-sInterpValTest))
testErr = sum(abs, sInterpVal-sInterpValTest)

if (testErr > eps)
display("Failed Simplex Comparison Test")
Expand Down Expand Up @@ -317,12 +317,12 @@ end
# by comparing against ind2x for each unrolled index
function test_vertices_ordering(grid)

grid_verts = vertices(grid)
grid_verts = @inferred vertices(grid)

@test length(grid_verts) == length(grid)*dimensions(grid)
@test length(grid_verts) == length(grid)

for i = 1 : length(grid)
@test grid_verts[i,:] == ind2x(grid,i)
@test grid_verts[i] == ind2x(grid,i)
end

return true
Expand Down

0 comments on commit 2eb4e4e

Please sign in to comment.