Skip to content

Commit

Permalink
Prevent panic on self-referencing structs/map/slices during Marshal
Browse files Browse the repository at this point in the history
  • Loading branch information
tsuperis3112 committed Dec 31, 2024
1 parent 71ac162 commit 6f40716
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 13 deletions.
2 changes: 1 addition & 1 deletion reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func _createEncoderOfType(ctx *ctx, typ reflect2.Type) ValEncoder {
kind := typ.Kind()
switch kind {
case reflect.Interface:
return &dynamicEncoder{typ}
return &dynamicEncoder{valType: typ, seen: make(map[unsafe.Pointer]bool, 1)}
case reflect.Struct:
return encoderOfStruct(ctx, typ)
case reflect.Array:
Expand Down
11 changes: 10 additions & 1 deletion reflect_dynamic.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
package jsoniter

import (
"github.com/modern-go/reflect2"
"reflect"
"unsafe"

"github.com/modern-go/reflect2"
)

type dynamicEncoder struct {
valType reflect2.Type
seen map[unsafe.Pointer]bool
}

func (encoder *dynamicEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
if encoder.seen[ptr] {
stream.Error = ErrEncounterCycle
return
}
encoder.seen[ptr] = true
defer delete(encoder.seen, ptr)

obj := encoder.valType.UnsafeIndirect(ptr)
stream.WriteVal(obj)
}
Expand Down
5 changes: 3 additions & 2 deletions reflect_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package jsoniter

import (
"fmt"
"github.com/modern-go/reflect2"
"reflect"
"sort"
"strings"
"unicode"
"unsafe"

"github.com/modern-go/reflect2"
)

var typeDecoders = map[string]ValDecoder{}
Expand Down Expand Up @@ -325,7 +326,7 @@ func _getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder {
typePtr := typ.(*reflect2.UnsafePtrType)
encoder := typeEncoders[typePtr.Elem().String()]
if encoder != nil {
return &OptionalEncoder{encoder}
return &OptionalEncoder{ValueEncoder: encoder, seen: make(map[unsafe.Pointer]bool, 1)}
}
}
return nil
Expand Down
18 changes: 14 additions & 4 deletions reflect_optional.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package jsoniter

import (
"github.com/modern-go/reflect2"
"unsafe"

"github.com/modern-go/reflect2"
)

func decoderOfOptional(ctx *ctx, typ reflect2.Type) ValDecoder {
Expand All @@ -16,7 +17,7 @@ func encoderOfOptional(ctx *ctx, typ reflect2.Type) ValEncoder {
ptrType := typ.(*reflect2.UnsafePtrType)
elemType := ptrType.Elem()
elemEncoder := encoderOfType(ctx, elemType)
encoder := &OptionalEncoder{elemEncoder}
encoder := &OptionalEncoder{ValueEncoder: elemEncoder, seen: make(map[unsafe.Pointer]bool, 1)}
return encoder
}

Expand Down Expand Up @@ -61,13 +62,22 @@ func (decoder *dereferenceDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {

type OptionalEncoder struct {
ValueEncoder ValEncoder
seen map[unsafe.Pointer]bool
}

func (encoder *OptionalEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
if *((*unsafe.Pointer)(ptr)) == nil {
ptr = *((*unsafe.Pointer)(ptr))
if encoder.seen[ptr] {
stream.Error = ErrEncounterCycle
return
}
encoder.seen[ptr] = true
defer delete(encoder.seen, ptr)

if ptr == nil {
stream.WriteNil()
} else {
encoder.ValueEncoder.Encode(*((*unsafe.Pointer)(ptr)), stream)
encoder.ValueEncoder.Encode(ptr, stream)
}
}

Expand Down
5 changes: 3 additions & 2 deletions reflect_struct_encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package jsoniter

import (
"fmt"
"github.com/modern-go/reflect2"
"io"
"reflect"
"unsafe"

"github.com/modern-go/reflect2"
)

func encoderOfStruct(ctx *ctx, typ reflect2.Type) ValEncoder {
Expand Down Expand Up @@ -54,7 +55,7 @@ func createCheckIsEmpty(ctx *ctx, typ reflect2.Type) checkIsEmpty {
kind := typ.Kind()
switch kind {
case reflect.Interface:
return &dynamicEncoder{typ}
return &dynamicEncoder{valType: typ, seen: make(map[unsafe.Pointer]bool, 1)}
case reflect.Struct:
return &structEncoder{typ: typ}
case reflect.Array:
Expand Down
3 changes: 3 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package jsoniter

import (
"errors"
"io"
)

var ErrEncounterCycle = errors.New("encountered a cycle")

// stream is a io.Writer like object, with JSON specific write functions.
// Error is not returned as return value, but stored as Error member on this stream instance.
type Stream struct {
Expand Down
6 changes: 5 additions & 1 deletion value_tests/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,13 @@ func init() {
"2018-12-14": true
}`,
}, unmarshalCase{
ptr: (*map[customKey]string)(nil),
ptr: (*map[customKey]string)(nil),
input: `{"foo": "bar"}`,
})

selfRecursive := map[string]interface{}{}
selfRecursive["me"] = selfRecursive
marshalSelfRecursiveCases = append(marshalSelfRecursiveCases, selfRecursive)
}

type MyInterface interface {
Expand Down
4 changes: 4 additions & 0 deletions value_tests/slice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,8 @@ func init() {
ptr: (*[]byte)(nil),
input: `"c3ViamVjdHM\/X2Q9MQ=="`,
})

selfRecursive := []interface{}{nil}
selfRecursive[0] = selfRecursive
marshalSelfRecursiveCases = append(marshalSelfRecursiveCases, selfRecursive)
}
4 changes: 4 additions & 0 deletions value_tests/struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ func init() {
"should not marshal",
},
)

selfRecursive := &structRecursive{}
selfRecursive.Me = selfRecursive
marshalSelfRecursiveCases = append(marshalSelfRecursiveCases, selfRecursive)
}

type StructVarious struct {
Expand Down
19 changes: 17 additions & 2 deletions value_tests/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package test
import (
"encoding/json"
"fmt"
"github.com/json-iterator/go"
"testing"

jsoniter "github.com/json-iterator/go"
"github.com/modern-go/reflect2"
"github.com/stretchr/testify/require"
"testing"
)

type unmarshalCase struct {
Expand All @@ -22,6 +23,8 @@ var marshalCases = []interface{}{
nil,
}

var marshalSelfRecursiveCases = []interface{}{}

type selectedMarshalCase struct {
marshalCase interface{}
}
Expand Down Expand Up @@ -78,3 +81,15 @@ func Test_marshal(t *testing.T) {
})
}
}

func Test_marshal_self_recursive(t *testing.T) {
for i, testCase := range marshalSelfRecursiveCases {
t.Run(fmt.Sprintf("[%v]%s", i, reflect2.TypeOf(testCase).String()), func(t *testing.T) {
should := require.New(t)
_, err1 := json.Marshal(testCase)
should.ErrorContains(err1, "encountered a cycle")
_, err2 := jsoniter.ConfigCompatibleWithStandardLibrary.Marshal(testCase)
should.ErrorContains(err2, "encountered a cycle")
})
}
}

0 comments on commit 6f40716

Please sign in to comment.