Skip to content

Commit

Permalink
Fix handling of Unions in construct/deconstruct (#21)
Browse files Browse the repository at this point in the history
* work

* Fix handling of Unions in deconstruct
  • Loading branch information
quinnj authored Sep 1, 2022
1 parent e02a4d4 commit b3d30c8
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
43 changes: 31 additions & 12 deletions src/Strapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ function construct(::StructTypes.CustomStruct, ::Type{T}, row, prefix=Symbol(),
end

function construct(::StructTypes.Struct, ::Type{T}, row, prefix=Symbol(), offset=Ref(0); kw...) where {T}
# @print 3 (T, prefix, offset)
# @show (T, T isa Union, prefix, offset)
return StructTypes.construct(T) do i, nm, TT
x = construct(StructTypes.StructType(TT), TT, row, Symbol(prefix, StructTypes.fieldprefix(T, nm)), offset, offset[] + 1, Symbol(prefix, nm); kw...)
# @print 3 x
Expand All @@ -155,7 +155,11 @@ end

function construct(::ST, ::Type{T}, row, prefix, offset, i, nm; kw...) where {ST <: Union{StructTypes.Struct, StructTypes.Mutable, StructTypes.CustomStruct}, T}
# @print 3 (T, prefix, offset, i, nm)
construct(ST(), T, row, prefix, offset; kw...)
if T isa Union
construct(ST(), Any, row, prefix, offset, i, nm; kw...)
else
construct(ST(), T, row, prefix, offset; kw...)
end
end

function construct!(::StructTypes.CustomStruct, x::T, row, prefix=Symbol(), offset=Ref(0); kw...) where {T}
Expand Down Expand Up @@ -374,6 +378,7 @@ getfieldvalue(::StructTypes.ArrayType, x, ind, fn) = isempty(x) ? missing : getf

function getfieldvalue(::Union{StructTypes.Struct, StructTypes.Mutable}, x, ind, fn)
val = Core.getfield(x, fn.index)
# @show val, ind, x, fn, fn.index, fn.subfield
return getfieldvalue(val, ind, fn.subfield)
end

Expand All @@ -383,6 +388,7 @@ function getfieldvalue(::StructTypes.CustomStruct, x, ind, fn)
end

function getfieldvalue(ST, x, ind, fn)
# @show ST, x, ind, fn
@assert fn === nothing
return x
end
Expand Down Expand Up @@ -482,11 +488,11 @@ struct EmptyArrayTypeValue end
function (f::DeconstructClosure)(::Union{StructTypes.Struct, StructTypes.Mutable}, x::T) where {T}
# x is root object
reset!(f)
f.parentType = T
StructTypes.foreachfield(x) do i, nm, TT, v
# each root field is a separate prefix/fieldnode ancestry branch
f.prefix = Symbol()
f.fieldnode = nothing
f.parentType = T
f(i, nm, TT, v)
end
return
Expand Down Expand Up @@ -515,18 +521,28 @@ end
(f::DeconstructClosure)(i, nm, TT, v; kw...) = f(StructTypes.StructType(TT), i, nm, TT, v)
(f::DeconstructClosure)(i, nm, TT; kw...) = f(StructTypes.StructType(TT), i, nm, TT, EmptyArrayTypeValue())
function (f::DeconstructClosure)(::Union{StructTypes.Struct, StructTypes.Mutable}, i, nm, TT, v)
f.prefix = Symbol(f.prefix, StructTypes.fieldprefix(f.parentType, nm))
f.parentType = TT
f.fieldnode = getfieldnode(f, FieldNode(i, nm, nothing))
StructTypes.foreachfield(f, v)
prefix = Symbol(f.prefix, StructTypes.fieldprefix(f.parentType, nm))
fieldnode = getfieldnode(f, FieldNode(i, nm, nothing))
StructTypes.foreachfield(v) do i2, nm2, TT2, v2
# reset prefix, parentType, fieldnode for each field
f.prefix = prefix
f.parentType = TT
f.fieldnode = fieldnode
f(i2, nm2, TT2, v2)
end
return
end

function (f::DeconstructClosure)(::Union{StructTypes.Struct, StructTypes.Mutable}, i, nm, TT, ::EmptyArrayTypeValue)
f.prefix = Symbol(f.prefix, StructTypes.fieldprefix(f.parentType, nm))
f.parentType = TT
f.fieldnode = getfieldnode(f, FieldNode(i, nm, nothing))
StructTypes.foreachfield(f, TT)
prefix = Symbol(f.prefix, StructTypes.fieldprefix(f.parentType, nm))
fieldnode = getfieldnode(f, FieldNode(i, nm, nothing))
StructTypes.foreachfield(TT) do i2, nm2, TT2, v2
# reset prefix, parentType, fieldnode for each field
f.prefix = prefix
f.parentType = TT
f.fieldnode = fieldnode
f(i2, nm2, TT2, v2)
end
return
end

Expand Down Expand Up @@ -561,7 +577,10 @@ function (f::DeconstructClosure)(::StructTypes.ArrayType, i, nm, TT, v)
return
end

function (f::DeconstructClosure)(ST, i, nm, TT, v)
(f::DeconstructClosure)(ST::StructTypes.Struct, i, nm, U::Union, v) = deconstruct_leaf(f, i, nm, U)
(f::DeconstructClosure)(ST, i, nm, TT, v) = deconstruct_leaf(f, i, nm, TT)

function deconstruct_leaf(f::DeconstructClosure, i, nm, TT)
if f.i > length(f.names)
# first time deconstructing obj
push!(f.names, Symbol(f.prefix, nm))
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,4 @@ tbl = Strapping.deconstruct(w) |> Tables.columntable
@test tbl == (a = [1], b = ["hey"])
w2 = Strapping.construct(Wrapper, tbl)
@test w == w2

0 comments on commit b3d30c8

Please sign in to comment.