diff --git a/examples/lj_forces.jl b/examples/lj_forces.jl index 63c0dd6..3acb292 100644 --- a/examples/lj_forces.jl +++ b/examples/lj_forces.jl @@ -42,5 +42,5 @@ command(lmp, "compute pot_e all pe") command(lmp, "run 0") # extract output -forces = gather(lmp, "f") -energies = gather(lmp, "pot_e") +forces = gather(lmp, "f", LAMMPS_DOUBLE_2D) +energies = extract_compute(lmp, "pot_e", STYLE_GLOBAL, TYPE_SCALAR) diff --git a/examples/snap.jl b/examples/snap.jl index ef94ba9..f1246e8 100644 --- a/examples/snap.jl +++ b/examples/snap.jl @@ -40,7 +40,7 @@ function run_snap(lmp, path, rcut, twojmax) """) ## Extract bispectrum - bs = gather(lmp, "SNA", Float64) + bs = gather(lmp, "c_SNA", LAMMPS_DOUBLE_2D) return bs end diff --git a/src/LAMMPS.jl b/src/LAMMPS.jl index d9d3661..c270817 100644 --- a/src/LAMMPS.jl +++ b/src/LAMMPS.jl @@ -1,11 +1,87 @@ module LAMMPS + import MPI +using Preferences + include("api.jl") -export LMP, command, get_natoms, extract_atom, extract_compute, extract_global, - gather, scatter!, group_to_atom_ids, get_category_ids - -using Preferences +export +# Core + LMP, + command, + get_natoms, +# Gather/Scatter operations + gather, + scatter!, + gather_angles, + gather_bonds, + gather_dihedrals, + gather_impropers, +# Extracts + extract_setting, + extract_atom, + extract_compute, + extract_global, + extract_variable, +# Utilities + group_to_atom_ids, + get_category_ids, +# Datatypes + LAMMPS_INT, + LAMMPS_INT_2D, + LAMMPS_DOUBLE, + LAMMPS_DOUBLE_2D, + LAMMPS_INT64, + LAMMPS_INT64_2D, + LAMMPS_STRING, +# Types + TYPE_SCALAR, + TYPE_VECTOR, + TYPE_ARRAY, + SIZE_COLS, + SIZE_ROWS, + SIZE_VECTOR, +# Styles + STYLE_GLOBAL, + STYLE_ATOM, + STYLE_LOCAL, +# Variables + VARIABLE_EQUAL, + VARIABLE_ATOM, + VARIABLE_VECTOR, + VARIABLE_STRING + +struct _LMP_DATATYPE{N} end + +const LAMMPS_INT = _LMP_DATATYPE{0}() +const LAMMPS_INT_2D = _LMP_DATATYPE{1}() +const LAMMPS_DOUBLE = _LMP_DATATYPE{2}() +const LAMMPS_DOUBLE_2D = _LMP_DATATYPE{3}() +const LAMMPS_INT64 = _LMP_DATATYPE{4}() +const LAMMPS_INT64_2D = _LMP_DATATYPE{5}() +const LAMMPS_STRING = _LMP_DATATYPE{6}() + +struct _LMP_TYPE{N} end + +const TYPE_SCALAR = _LMP_TYPE{0}() +const TYPE_VECTOR = _LMP_TYPE{1}() +const TYPE_ARRAY = _LMP_TYPE{2}() +const SIZE_VECTOR = _LMP_TYPE{3}() +const SIZE_ROWS = _LMP_TYPE{4}() +const SIZE_COLS = _LMP_TYPE{5}() + +struct _LMP_STYLE{N} end + +const STYLE_GLOBAL = _LMP_STYLE{0}() +const STYLE_ATOM = _LMP_STYLE{1}() +const STYLE_LOCAL = _LMP_STYLE{2}() + +struct LMP_VARIABLE{N} end + +const VARIABLE_EQUAL = LMP_VARIABLE{0}() +const VARIABLE_ATOM = LMP_VARIABLE{1}() +const VARIABLE_VECTOR = LMP_VARIABLE{2}() +const VARIABLE_STRING = LMP_VARIABLE{3}() """ locate() @@ -97,13 +173,10 @@ end LMP(f::Function, args=String[], comm=nothing) Create a new LAMMPS instance and call `f` on that instance while returning the result from `f`. -This constructor closes the LAMMPS instance immediately after `f` has executed. """ function LMP(f::Function, args=String[], comm=nothing) lmp = LMP(args, comm) - result = f(lmp) - close!(lmp) - return result + return f(lmp) end function version(lmp::LMP) @@ -181,238 +254,213 @@ function get_natoms(lmp::LMP) Int64(API.lammps_get_natoms(lmp)) end -function dtype2type(dtype::API._LMP_DATATYPE_CONST) - if dtype == API.LAMMPS_INT - type = Ptr{Int32} - elseif dtype == API.LAMMPS_INT_2D - type = Ptr{Ptr{Int32}} - elseif dtype == API.LAMMPS_INT64 - type = Ptr{Int64} - elseif dtype == API.LAMMPS_INT64_2D - type = Ptr{Ptr{Int64}} - elseif dtype == API.LAMMPS_DOUBLE - type = Ptr{Float64} - elseif dtype == API.LAMMPS_DOUBLE_2D - type = Ptr{Ptr{Float64}} - elseif dtype == API.LAMMPS_STRING - type = Ptr{Cchar} - else - @assert false "Unknown dtype: $dtype" - end - return type +function int2type(dtype) + dtype == 0 && return LAMMPS_INT + dtype == 1 && return LAMMPS_INT_2D + dtype == 2 && return LAMMPS_DOUBLE + dtype == 3 && return LAMMPS_DOUBLE_2D + dtype == 4 && return LAMMPS_INT64 + dtype == 5 && return LAMMPS_INT64_2D + dtype == 6 && return LAMMPS_STRING + + error("Unknown lammps data type: $dtype") end -""" - extract_global(lmp, name, dtype=nothing) -""" -function extract_global(lmp::LMP, name, dtype=nothing) - if dtype === nothing - dtype = API.lammps_extract_global_datatype(lmp, name) - end - dtype = API._LMP_DATATYPE_CONST(dtype) - type = dtype2type(dtype) +function type2julia(type::_LMP_DATATYPE) + type == LAMMPS_INT && return Vector{Int32} + type == LAMMPS_INT_2D && return Matrix{Int32} + type == LAMMPS_DOUBLE && return Vector{Float64} + type == LAMMPS_DOUBLE_2D && return Matrix{Float64} + type == LAMMPS_INT64 && return Vector{Int64} + type == LAMMPS_INT64_2D && return Matrix{Int64} + type == LAMMPS_STRING && return String +end - ptr = API.lammps_extract_global(lmp, name) - ptr = reinterpret(type, ptr) +function array2type(array) + array === Vector{Int32} && return LAMMPS_INT + array === Matrix{Int32} && return LAMMPS_INT_2D + array === Vector{Float64} && return LAMMPS_DOUBLE + array === Matrix{Float64} && return LAMMPS_DOUBLE_2D + array === Vector{Int64} && return LAMMPS_INT64 + array === Matrix{Int64} && return LAMMPS_INT64_2D + array === String && return LAMMPS_STRING +end - if ptr !== C_NULL - if dtype == API.LAMMPS_STRING - return Base.unsafe_string(ptr) - end - # TODO: deal with non-scalar data - return Base.unsafe_load(ptr) - end +is_2D(N::Integer) = N in (1, 3, 5) +is_2D(::_LMP_DATATYPE{N}) where N = N in (1, 3, 5) +Base.Int(::_LMP_DATATYPE{N}) where N = N +Base.Int(::LMP_VARIABLE{N}) where N = N +Base.Int(::_LMP_TYPE{N}) where N = N +Base.Int(::_LMP_STYLE{N}) where N = N + +function lammps_reinterpret(T::_LMP_DATATYPE, ptr::Ptr) + # we're pretty much guaranteed to call lammps_reinterpret after reciving a pointer + # from LAMMPS. So this is a good spot catch NULL-pointers and avoid Segfaults + ptr == C_NULL && error("reinterpreting NULL-pointer!") + + T === LAMMPS_INT && return Base.reinterpret(Ptr{Int32}, ptr) + T === LAMMPS_INT_2D && return Base.reinterpret(Ptr{Ptr{Int32}}, ptr) + T === LAMMPS_DOUBLE && return Base.reinterpret(Ptr{Float64}, ptr) + T === LAMMPS_DOUBLE_2D && return Base.reinterpret(Ptr{Ptr{Float64}}, ptr) + T === LAMMPS_INT64 && return Base.reinterpret(Ptr{Int64}, ptr) + T === LAMMPS_INT64_2D && return Base.reinterpret(Ptr{Ptr{Int64}}, ptr) + T === LAMMPS_STRING && return Base.reinterpret(Ptr{UInt8}, ptr) end -function unsafe_wrap(ptr, shape) - if length(shape) > 1 - # We got a list of ptrs, - # but the first pointer points to the whole data - ptr = Base.unsafe_load(ptr) +""" + extract_global(lmp::LMP, name::String, dtype::_LMP_DATATYPE; copy=true) +""" +function extract_global(lmp::LMP, name::String, dtype::_LMP_DATATYPE; copy=true) + @assert API.lammps_extract_global_datatype(lmp, name) == Int(dtype) + + ptr = lammps_reinterpret(dtype, API.lammps_extract_global(lmp, name)) - @assert length(shape) == 2 + dtype == LAMMPS_STRING && return lammps_unsafe_string(ptr, copy) - # Note: Julia like Fortran is column-major - # so the data is transposed from Julia's perspective - shape = reverse(shape) + if name in ("boxlo", "boxhi", "sublo", "subhi", "sublo_lambda", "subhi_lambda", "periodicity") + length = 3 + elseif name in ("special_lj", "special_coul") + length = 4 + else + length = 1 end - # TODO: Who is responsible for freeing this data - array = Base.unsafe_wrap(Array, ptr, shape, own=false) - return array + return lammps_unsafe_wrap(ptr, length, copy) end -""" - extract_atom(lmp, name, dtype=nothing, axes1, axes2) -""" -function extract_atom(lmp::LMP, name, - dtype::Union{Nothing, API._LMP_DATATYPE_CONST} = nothing, - axes1=nothing, axes2=nothing) +function lammps_unsafe_string(ptr::Ptr, copy=true) + result = Base.unsafe_string(ptr) + return copy ? deepcopy(result) : result +end +function lammps_unsafe_wrap(ptr::Ptr{<:Real}, shape::Integer, copy=true) + result = Base.unsafe_wrap(Array, ptr, shape, own=false) + return copy ? Base.copy(result) : result +end - if dtype === nothing - dtype = API.lammps_extract_atom_datatype(lmp, name) - dtype = API._LMP_DATATYPE_CONST(dtype) - end +function lammps_unsafe_wrap(ptr::Ptr{<:Ptr{T}}, shape::NTuple{2}, copy=true) where T + (count, ndata) = shape - if axes1 === nothing - if name == "mass" - axes1 = extract_global(lmp, "ntypes") + 1 - else - axes1 = extract_global(lmp, "nlocal") % Int - end - end + ndata == 0 && return Matrix{T}(undef, count, ndata) - if axes2 === nothing - if dtype in (API.LAMMPS_INT_2D, API.LAMMPS_INT64_2D, API.LAMMPS_DOUBLE_2D) - # TODO: Other fields? - if name in ("x", "v", "f", "angmom", "torque", "csforce", "vforce") - axes2 = 3 - else - axes2 = 2 - end - end - end + pointers = Base.unsafe_wrap(Array, ptr, ndata) - if axes2 !== nothing - shape = (axes1, axes2) - else - shape = (axes1, ) - end + @assert all(diff(pointers) .== count*sizeof(T)) + result = Base.unsafe_wrap(Array, pointers[1], shape, own=false) - type = dtype2type(dtype) - ptr = API.lammps_extract_atom(lmp, name) - ptr = reinterpret(type, ptr) + return copy ? Base.copy(result) : result +end + +""" + extract_setting(lmp::LMP, name::String) - unsafe_wrap(ptr, shape) + +""" +function extract_setting(lmp::LMP, name::String) + return API.lammps_extract_setting(lmp, name) end -function unsafe_extract_compute(lmp::LMP, name, style, type) - if type == API.LMP_TYPE_SCALAR - if style == API.LMP_STYLE_GLOBAL - dtype = Ptr{Float64} - elseif style == API.LMP_STYLE_LOCAL - dtype = Ptr{Cint} - elseif style == API.LMP_STYLE_ATOM - return nothing - end - extract = true - elseif type == API.LMP_TYPE_VECTOR - dtype = Ptr{Float64} - extract = false - elseif type == API.LMP_TYPE_ARRAY - dtype = Ptr{Ptr{Float64}} - extract = false - elseif type == API.LMP_SIZE_COLS - dtype = Ptr{Cint} - extract = true - elseif type == API.LMP_SIZE_ROWS || - type == API.LMP_SIZE_VECTOR - if style == API.LMP_STYLE_ATOM - return nothing - end - dtype = Ptr{Cint} - extract = true - else - @assert false "Unknown type: $type" - end +""" + extract_atom(lmp::LMP, name::String, dtype::_LMP_DATATYPE; copy=true) +""" +function extract_atom(lmp::LMP, name::String, dtype::_LMP_DATATYPE; copy=true) + @assert API.lammps_extract_atom_datatype(lmp, name) == Int(dtype) - ptr = API.lammps_extract_compute(lmp, name, style, type) - ptr == C_NULL && check(lmp) + ptr = lammps_reinterpret(dtype, API.lammps_extract_atom(lmp, name)) + @assert ptr != C_NULL - if ptr == C_NULL - error("Could not extract_compute $name with $style and $type") + if name == "mass" + length = extract_global(lmp, "ntypes", LAMMPS_INT, copy=false)[] + ptr += sizeof(Float64) # Scarry pointer arithemtic + result = lammps_unsafe_wrap(ptr, length, false) + + return copy ? Base.copy(result) : result end - ptr = reinterpret(dtype, ptr) - if extract - return Base.unsafe_load(ptr) + length = extract_setting(lmp, "nlocal") + + if is_2D(dtype) + count = name == "quat" ? Int32(4) : Int32(3) # only Quaternions have 4 entries + return lammps_unsafe_wrap(ptr, (count, length), copy) end - return ptr + + return lammps_unsafe_wrap(ptr, length, copy) end """ - extract_compute(lmp, name, style, type) + extract_compute(lmp::LMP, name::String, style::_LMP_STYLE, type::_LMP_TYPE; copy=true) """ -function extract_compute(lmp::LMP, name, style, type) - ptr_or_value = unsafe_extract_compute(lmp, name, style, type) - if style == API.LMP_TYPE_SCALAR - return ptr_or_value +function extract_compute(lmp::LMP, name::String, style::_LMP_STYLE, type::_LMP_TYPE; copy=true) + void_ptr = API.lammps_extract_compute(lmp, name, Int(style), Int(type)) + @assert void_ptr != C_NULL + + if type in (SIZE_COLS, SIZE_ROWS, SIZE_VECTOR) + ptr = lammps_reinterpret(LAMMPS_INT, void_ptr) + return lammps_unsafe_wrap(ptr, 1, copy) end - if ptr_or_value === nothing - return nothing + + if type == TYPE_SCALAR + ptr = lammps_reinterpret(LAMMPS_DOUBLE, void_ptr) + return lammps_unsafe_wrap(ptr, 1, copy) end - ptr = ptr_or_value::Ptr - - if style in (API.LMP_STYLE_GLOBAL, API.LMP_STYLE_LOCAL) - if type == API.LMP_TYPE_VECTOR - nrows = unsafe_extract_compute(lmp, name, style, API.LMP_SIZE_VECTOR) - return unsafe_wrap(ptr, (nrows,)) - elseif type == API.LMP_TYPE_ARRAY - nrows = unsafe_extract_compute(lmp, name, style, API.LMP_SIZE_ROWS) - ncols = unsafe_extract_compute(lmp, name, style, API.LMP_SIZE_COLS) - return unsafe_wrap(ptr, (nrows, ncols)) - end - else style = API.LMP_STYLE_ATOM - nlocal = extract_global(lmp, "nlocal") - if type == API.LMP_TYPE_VECTOR - return unsafe_wrap(ptr, (nlocal,)) - elseif type == API.LMP_TYPE_ARRAY - ncols = unsafe_extract_compute(lmp, name, style, API.LMP_SIZE_COLS) - return unsafe_wrap(ptr, (nlocal, ncols)) - end + + ndata = style == STYLE_ATOM ? + extract_setting(lmp, "nlocal") : + extract_compute(lmp, name, style, TYPE_SCALAR, copy=false)[] + + if type == TYPE_VECTOR + ptr = lammps_reinterpret(LAMMPS_DOUBLE, void_ptr) + return lammps_unsafe_wrap(ptr, ndata, copy) end - return nothing + + count = extract_compute(lmp, name, style, SIZE_COLS)[] + ptr = lammps_reinterpret(LAMMPS_DOUBLE_2D, void_ptr) + + return lammps_unsafe_wrap(ptr, (count, ndata), copy) end """ - extract_variable(lmp::LMP, name, group) + extract_variable(lmp::LMP, name::String, variable::LMP_VARIABLE, group=C_NULL; copy=true) Extracts the data from a LAMMPS variable. When the variable is either an `equal`-style compatible variable, a `vector`-style variable, or an `atom`-style variable, the variable is evaluated and the corresponding value(s) returned. Variables of style `internal` are compatible with `equal`-style variables, if they return a numeric value. For other variable styles, their string value is returned. """ -function extract_variable(lmp::LMP, name::String, group=nothing) - var = API.lammps_extract_variable_datatype(lmp, name) - if var == -1 - throw(KeyError(name)) +function extract_variable(lmp::LMP, name::String, variable::LMP_VARIABLE, group=C_NULL; copy=true) + @assert variable == VARIABLE_ATOM || group == C_NULL "the group parameter is only supported for per atom variables!" + @assert API.lammps_extract_variable_datatype(lmp, name) == Int(variable) + + void_ptr = API.lammps_extract_variable(lmp, name, group) + @assert void_ptr != C_NULL + + if variable == VARIABLE_EQUAL + ptr = lammps_reinterpret(LAMMPS_DOUBLE, void_ptr) + result = unsafe_load(ptr) + API.lammps_free(ptr) + return result end - if group === nothing - group = C_NULL + + if variable == VARIABLE_VECTOR + ndata_ptr = lammps_reinterpret(LAMMPS_INT, API.lammps_extract_variable(lmp, name, "GET_VECTOR_SIZE")) + ndata = unsafe_load(ndata_ptr) + API.lammps_free(ndata_ptr) + + ptr = lammps_reinterpret(LAMMPS_DOUBLE, void_ptr) + return lammps_unsafe_wrap(ptr, ndata, copy) end - if var == API.LMP_VAR_EQUAL - ptr = API.lammps_extract_variable(lmp, name, C_NULL) - val = Base.unsafe_load(Base.unsafe_convert(Ptr{Float64}, ptr)) - API.lammps_free(ptr) - return val - elseif var == API.LMP_VAR_ATOM - nlocal = extract_global(lmp, "nlocal") - ptr = API.lammps_extract_variable(lmp, name, group) - if ptr == C_NULL - error("Group $group for variable $name with style atom not available.") - end - # LAMMPS uses malloc, so and we are taking ownership of this buffer - val = copy(Base.unsafe_wrap(Array, Base.unsafe_convert(Ptr{Float64}, ptr), nlocal; own=false)) - API.lammps_free(ptr) - return val - elseif var == API.LMP_VAR_VECTOR - # TODO Fix lammps docs `GET_VECTOR_SIZE` - ptr = API.lammps_extract_variable(lmp, name, "LMP_SIZE_VECTOR") - if ptr == C_NULL - error("$name is a vector style variable but has no size.") - end - sz = Base.unsafe_load(Base.unsafe_convert(Ptr{Cint}, ptr)) + if variable == VARIABLE_ATOM + ndata = extract_setting(lmp, "nlocal") + + ptr = lammps_reinterpret(LAMMPS_DOUBLE, void_ptr) + result = lammps_unsafe_wrap(ptr, ndata, true) API.lammps_free(ptr) - ptr = API.lammps_extract_variable(lmp, name, C_NULL) - return Base.unsafe_wrap(Array, Base.unsafe_convert(Ptr{Float64}, ptr), sz, own=false) - elseif var == API.LMP_VAR_STRING - ptr = API.lammps_extract_variable(lmp, name, C_NULL) - return Base.unsafe_string(Base.unsafe_convert(Ptr{Cchar}, ptr)) - else - error("Unkown variable style $var") + return result end + + ptr = lammps_reinterpret(LAMMPS_STRING, void_ptr) + return lammps_unsafe_string(ptr, copy) end @deprecate gather_atoms(lmp::LMP, name, T, count) gather(lmp, name, T) @@ -429,28 +477,36 @@ Compute entities have the prefix `c_`, fix entities use the prefix `f_`, and per The returned Array is decoupled from the internal state of the LAMMPS instance. -!!! warning "Type Verification" - Due to how the underlying C-API works, it's not possible to verify the element data-type of fix or compute style data. - Supplying the wrong data-type will not throw an error but will result in nonsensical output - !!! warning "ids" The optional parameter `ids` only works, if there is a map defined. For example by doing: `command(lmp, "atom_modify map yes")` However, LAMMPS only issues a warning if that's the case, which unfortuately cannot be detected through the underlying API. Starting form LAMMPS version `17 Apr 2024` this should no longer be an issue, as LAMMPS then throws an error instead of a warning. """ -function gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, ids::Union{Nothing, Array{Int32}}=nothing) +function gather(lmp::LMP, name::String, T::_LMP_DATATYPE, ids::Union{Nothing, Array{Int32}}=nothing) name == "mass" && error("scattering/gathering mass is currently not supported! Use `extract_atom()` instead.") count = _get_count(lmp, name) - _T = _get_T(lmp, name) + _dtype = _get_dtype(lmp, name) - @assert ismissing(_T) || _T == T "Expected data type $_T got $T instead." + @assert Int(T) in _dtype "Expected data type $(int2type.(_dtype)) got $T instead." + count > 1 && T in (LAMMPS_DOUBLE, LAMMPS_INT) && error("1") - dtype = (T === Float64) + dtype = T in (LAMMPS_DOUBLE, LAMMPS_DOUBLE_2D) natoms = get_natoms(lmp) ndata = isnothing(ids) ? natoms : length(ids) - data = Matrix{T}(undef, (count, ndata)) + + if T == LAMMPS_INT + data = Vector{Int32}(undef, ndata) + elseif T == LAMMPS_DOUBLE + data = Vector{Int32}(undef, ndata) + elseif T == LAMMPS_INT_2D + data = Matrix{Float64}(undef, (count, ndata)) + elseif T == LAMMPS_DOUBLE_2D + data = Matrix{Float64}(undef, (count, ndata)) + else + error("2") + end if isnothing(ids) API.lammps_gather(lmp, name, dtype, count, data) @@ -463,6 +519,46 @@ function gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, id return data end +function gather_bonds(lmp::LMP) + nbonds = extract_global(lmp, "nbonds", LAMMPS_INT64, copy=false)[] + data = Matrix{Int32}(undef, (3, nbonds)) + API.lammps_gather_bonds(lmp, data) + return data +end + +function gather_angles(lmp::LMP) + nangles = extract_global(lmp, "nangles", LAMMPS_INT64, copy=false)[] + data = Matrix{Int32}(undef, (4, nangles)) + API.lammps_gather_angles(lmp, data) + return data +end + +function gather_dihedrals(lmp::LMP) + ndihedrals = extract_global(lmp, "ndihedrals", LAMMPS_INT64, copy=false)[] + data = Matrix{Int32}(undef, (5, ndihedrals)) + API.lammps_gather_dihedrals(lmp, data) + return data +end + +function gather_impropers(lmp::LMP) + nimpropers = extract_global(lmp, "nimpropers", LAMMPS_INT64, copy=false)[] + data = Matrix{Int32}(undef, (5, nimpropers)) + API.lammps_gather_impropers(lmp, data) + return data +end + +function create_atoms(lmp::LMP, type, x; id=nothing, v=nothing, image=nothing, bexpand=false) + natoms = length(type) + + @assert size(x) == (3, natoms) + + isnothing(id) ? id = C_NULL : @assert size(id) == natoms + isnothing(v) ? v = C_NULL : @assert size(v) == (3, natoms) + isnothing(image) ? image = C_NULL : @assert size(image) = natoms + + return API.lammps_create_atoms(lmp, natoms, id, type, x, v, image, bexpand) +end + """ scatter!(lmp::LMP, name::String, data::VecOrMat{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} @@ -472,23 +568,19 @@ The optional parameter `ids` determines to which subset of atoms the data will b Compute entities have the prefix `c_`, fix entities use the prefix `f_`, and per-atom entites have no prefix. -!!! warning "Type Verification" - Due to how the underlying C-API works, it's not possible to verify the element data-type of fix or compute style data. - Supplying the wrong data-type will not throw an error but will result in nonsensical date being supplied to the LAMMPS instance. - !!! warning "ids" The optional parameter `ids` only works, if there is a map defined. For example by doing: `command(lmp, "atom_modify map yes")` However, LAMMPS only issues a warning if that's the case, which unfortuately cannot be detected through the underlying API. Starting form LAMMPS version `17 Apr 2024` this should no longer be an issue, as LAMMPS then throws an error instead of a warning. """ -function scatter!(lmp::LMP, name::String, data::VecOrMat{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64} +function scatter!(lmp::LMP, name::String, data::VecOrMat{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Float64, Int32} name == "mass" && error("scattering/gathering mass is currently not supported! Use `extract_atom()` instead.") count = _get_count(lmp, name) - _T = _get_T(lmp, name) + _T = _get_dtype(lmp, name) - @assert ismissing(_T) || _T == T "Expected data type $_T got $T instead." + @assert Int(array2type(typeof(data))) in _T dtype = (T === Float64) natoms = get_natoms(lmp) @@ -518,47 +610,55 @@ function _get_count(lmp::LMP, name::String) if startswith(name, r"[f,c]_") if name[1] == 'c' API.lammps_has_id(lmp, "compute", name[3:end]) != 1 && error("Unknown per atom compute $name") - count_ptr = API.lammps_extract_compute(lmp::LMP, name[3:end], API.LMP_STYLE_ATOM, API.LMP_SIZE_COLS) else API.lammps_has_id(lmp, "fix", name[3:end]) != 1 && error("Unknown per atom fix $name") - count_ptr = API.lammps_extract_fix(lmp::LMP, name[3:end], API.LMP_STYLE_ATOM, API.LMP_SIZE_COLS, 0, 0) end - check(lmp) - - count_ptr = reinterpret(Ptr{Cint}, count_ptr) + count_ptr == C_NULL && error("compute $name does not have per atom data") + count_ptr = lammps_reinterpret(LAMMPS_INT, count_ptr) count = unsafe_load(count_ptr) # a count of 0 indicates that the entity is a vector. In order to perserve type stability we just treat that as a 1xN Matrix. return count == 0 ? 1 : count - elseif name in ("mass", "id", "type", "mask", "image", "molecule", "q", "radius", "rmass", "ellipsoid", "line", "tri", "body", "temperature", "heatflow") - return 1 - elseif name in ("x", "v", "f", "mu", "omega", "angmom", "torque") - return 3 - elseif name == "quat" - return 4 else - error("Unknown per atom property $name") + dtype = API.lammps_extract_atom_datatype(lmp, name) + dtype == -1 && error("Unkown per atom property $name") + + name == "quat" && return 4 + is_2D(dtype) && return 3 + return 1 + end end -function _get_T(lmp::LMP, name::String) +function _get_dtype(lmp::LMP, name::String) if startswith(name, r"[f,c]_") - return missing # As far as I know, it's not possible to determine the datatype of computes or fixes at runtime + return (Int(LAMMPS_DOUBLE), Int(LAMMPS_DOUBLE_2D)) + else + return (API.lammps_extract_atom_datatype(lmp, name), ) end +end - type = API.lammps_extract_atom_datatype(lmp, name) - check(lmp) +function decode_image_flags(images::Vector{<:Integer}) + data = Matrix{Int32}(undef, (3, length(images))) - if type in (API.LAMMPS_INT, API.LAMMPS_INT_2D) - return Int32 - elseif type in (API.LAMMPS_DOUBLE, API.LAMMPS_DOUBLE_2D) - return Float64 - else - error("Unkown per atom property $name") + for (i, image) in pairs(images) + data_view = @view data[:, i] + API.lammps_decode_image_flags(image, data_view) end + return data +end + +function encode_image_flags(images::Matrix{<:Integer}) + @assert size(images, 1) == 3 + + return [API.lammps_encode_image_flags(img[1], img[2], img[3]) for img in eachcol(images)] +end + +function is_running(lmp) + return API.lammps_is_running(lmp) > 0 end """ @@ -579,7 +679,7 @@ function group_to_atom_ids(lmp::LMP, group::String) API.lammps_id_name(lmp, "group", idx, buffer, buffer_size) buffer != name_padded && continue - mask = gather(lmp, "mask", Int32)[:] .& (1 << idx) .!= 0 + mask = gather(lmp, "mask", LAMMPS_INT) .& (1 << idx) .!= 0 all_ids = UnitRange{Int32}(1, get_natoms(lmp)) return all_ids[mask] diff --git a/test/runtests.jl b/test/runtests.jl index 8cd826b..4fb5836 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,26 +17,28 @@ end @testset "Variables" begin LMP(["-screen", "none"]) do lmp - command(lmp, "box tilt large") - command(lmp, "region cell block 0 1.0 0 1.0 0 1.0 units box") - command(lmp, "create_box 1 cell") - command(lmp, "create_atoms 1 random 10 1 NULL") - command(lmp, "compute press all pressure NULL pair"); - command(lmp, "fix press all ave/time 1 1 1 c_press mode vector"); - - command(lmp, "variable var1 equal 1.0") - command(lmp, "variable var2 string \"hello\"") - command(lmp, "variable var3 atom x") - # TODO: x is 3d, how do we access more than the first dims - command(lmp, "variable var4 vector f_press") - - @test LAMMPS.extract_variable(lmp, "var1") == 1.0 - @test LAMMPS.extract_variable(lmp, "var2") == "hello" - x = LAMMPS.extract_atom(lmp, "x") - x_var = LAMMPS.extract_variable(lmp, "var3") + command(lmp, """ + box tilt large + region cell block 0 1.0 0 1.0 0 1.0 units box + create_box 1 cell + create_atoms 1 random 10 1 NULL + compute press all pressure NULL pair + fix press all ave/time 1 1 1 c_press mode vector + + variable var1 equal 1.0 + variable var2 string \"hello\" + variable var3 atom x + # TODO: x is 3d, how do we access more than the first dims + variable var4 vector f_press + """) + + @test extract_variable(lmp, "var1", VARIABLE_EQUAL) == 1.0 + @test extract_variable(lmp, "var2", VARIABLE_STRING) == "hello" + x = extract_atom(lmp, "x", LAMMPS_DOUBLE_2D) + x_var = extract_variable(lmp, "var3", VARIABLE_ATOM) @test length(x_var) == 10 @test x_var == x[1, :] - press = LAMMPS.extract_variable(lmp, "var4") + press = LAMMPS.extract_variable(lmp, "var4", VARIABLE_VECTOR) @test press isa Vector{Float64} end end @@ -44,17 +46,20 @@ end @testset "gather/scatter" begin LMP(["-screen", "none"]) do lmp # setting up example data - command(lmp, "atom_modify map yes") - command(lmp, "region cell block 0 3 0 3 0 3") - command(lmp, "create_box 1 cell") - command(lmp, "lattice sc 1") - command(lmp, "create_atoms 1 region cell") - command(lmp, "mass 1 1") + command(lmp, """ + atom_modify map yes + region cell block 0 3 0 3 0 3 + create_box 1 cell + lattice sc 1 + create_atoms 1 region cell + mass 1 1 - command(lmp, "compute pos all property/atom x y z") - command(lmp, "fix pos all ave/atom 10 1 10 c_pos[1] c_pos[2] c_pos[3]") + compute pos all property/atom x y z + fix pos all ave/atom 10 1 10 c_pos[*] - command(lmp, "run 10") + run 10 + """) + data = zeros(Float64, 3, 27) subset = Int32.([2,5,10, 5]) data_subset = ones(Float64, 3, 4) @@ -63,15 +68,15 @@ end subset_bad2 = Int32.([0]) subset_bad_data = ones(Float64, 3,1) - @test_throws AssertionError gather(lmp, "x", Int32) - @test_throws AssertionError gather(lmp, "id", Float64) + @test_throws AssertionError gather(lmp, "x", LAMMPS_INT_2D) + @test_throws AssertionError gather(lmp, "id", LAMMPS_DOUBLE) - @test_throws ErrorException gather(lmp, "nonesense", Float64) - @test_throws ErrorException gather(lmp, "c_nonsense", Float64) - @test_throws ErrorException gather(lmp, "f_nonesense", Float64) + @test_throws ErrorException gather(lmp, "nonesense", LAMMPS_DOUBLE_2D) + @test_throws ErrorException gather(lmp, "c_nonsense", LAMMPS_DOUBLE_2D) + @test_throws ErrorException gather(lmp, "f_nonesense", LAMMPS_DOUBLE_2D) - @test_throws AssertionError gather(lmp, "x", Float64, subset_bad1) - @test_throws AssertionError gather(lmp, "x", Float64, subset_bad2) + @test_throws AssertionError gather(lmp, "x", LAMMPS_DOUBLE_2D, subset_bad1) + @test_throws AssertionError gather(lmp, "x", LAMMPS_DOUBLE_2D, subset_bad2) @test_throws ErrorException scatter!(lmp, "nonesense", data) @test_throws ErrorException scatter!(lmp, "c_nonsense", data) @@ -80,23 +85,23 @@ end @test_throws AssertionError scatter!(lmp, "x", subset_bad_data, subset_bad1) @test_throws AssertionError scatter!(lmp, "x", subset_bad_data, subset_bad2) - @test gather(lmp, "x", Float64) == gather(lmp, "c_pos", Float64) == gather(lmp, "f_pos", Float64) + @test gather(lmp, "x", LAMMPS_DOUBLE_2D) == gather(lmp, "c_pos", LAMMPS_DOUBLE_2D) == gather(lmp, "f_pos", LAMMPS_DOUBLE_2D) - @test gather(lmp, "x", Float64)[:,subset] == gather(lmp, "x", Float64, subset) - @test gather(lmp, "c_pos", Float64)[:,subset] == gather(lmp, "c_pos", Float64, subset) - @test gather(lmp, "f_pos", Float64)[:,subset] == gather(lmp, "f_pos", Float64, subset) + @test gather(lmp, "x", LAMMPS_DOUBLE_2D)[:,subset] == gather(lmp, "x", LAMMPS_DOUBLE_2D, subset) + @test gather(lmp, "c_pos", LAMMPS_DOUBLE_2D)[:,subset] == gather(lmp, "c_pos", LAMMPS_DOUBLE_2D, subset) + @test gather(lmp, "f_pos", LAMMPS_DOUBLE_2D)[:,subset] == gather(lmp, "f_pos", LAMMPS_DOUBLE_2D, subset) scatter!(lmp, "x", data) scatter!(lmp, "f_pos", data) scatter!(lmp, "c_pos", data) - @test gather(lmp, "x", Float64) == gather(lmp, "c_pos", Float64) == gather(lmp, "f_pos", Float64) == data + @test gather(lmp, "x", LAMMPS_DOUBLE_2D) == gather(lmp, "c_pos", LAMMPS_DOUBLE_2D) == gather(lmp, "f_pos", LAMMPS_DOUBLE_2D) == data scatter!(lmp, "x", data_subset, subset) scatter!(lmp, "c_pos", data_subset, subset) scatter!(lmp, "f_pos", data_subset, subset) - @test gather(lmp, "x", Float64, subset) == gather(lmp, "c_pos", Float64, subset) == gather(lmp, "f_pos", Float64, subset) == data_subset + @test gather(lmp, "x", LAMMPS_DOUBLE_2D, subset) == gather(lmp, "c_pos", LAMMPS_DOUBLE_2D, subset) == gather(lmp, "f_pos", LAMMPS_DOUBLE_2D, subset) == data_subset end end