Skip to content

Commit

Permalink
Fix Panic During Person Unmarshal (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
bbengfort authored Sep 9, 2024
1 parent 60aeda1 commit cb6ce3e
Show file tree
Hide file tree
Showing 6 changed files with 353 additions and 91 deletions.
170 changes: 170 additions & 0 deletions pkg/ivms101/db_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package ivms101_test

import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"os"
reflect "reflect"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -84,6 +88,172 @@ func TestIVMS101Database(t *testing.T) {
})
}

func TestIVMS101DatabaseScan(t *testing.T) {
type Object struct {
Identity *ivms101.IdentityPayload `json:"identity"`
Person *ivms101.Person `json:"person"`
NaturalPerson *ivms101.NaturalPerson `json:"naturalPerson"`
LegalPerson *ivms101.LegalPerson `json:"legalPerson"`
Address *ivms101.Address `json:"address"`
}

model := &Object{}
mock := &MockRow{}
err := mock.Open("testdata/identity_payload.json", "testdata/person_legal_person.json", "testdata/natural_person.json", "testdata/legal_person.json", "testdata/address.json")
require.NoError(t, err, "could not load mock fixtures")
require.Len(t, mock.raw, 5)

err = mock.Scan(&model.Identity, &model.Person, &model.NaturalPerson, &model.LegalPerson, &model.Address)
require.NoError(t, err)

require.NotEmpty(t, model.Identity)
require.NotEmpty(t, model.Person)
require.NotEmpty(t, model.NaturalPerson)
require.NotEmpty(t, model.LegalPerson)
require.NotEmpty(t, model.Address)
}

func TestIVMS101DatabaseScanNil(t *testing.T) {
type Object struct {
Identity *ivms101.IdentityPayload `json:"identity"`
Person *ivms101.Person `json:"person"`
NaturalPerson *ivms101.NaturalPerson `json:"naturalPerson"`
LegalPerson *ivms101.LegalPerson `json:"legalPerson"`
Address *ivms101.Address `json:"address"`
}

model := &Object{}
mock := &MockRow{}
err := mock.Open("", "", "", "", "")
require.NoError(t, err, "could not load mock fixtures")
require.Len(t, mock.raw, 5)

err = mock.Scan(&model.Identity, &model.Person, &model.NaturalPerson, &model.LegalPerson, &model.Address)
require.NoError(t, err)

require.Empty(t, model.Identity)
require.Empty(t, model.Person)
require.Empty(t, model.NaturalPerson)
require.Empty(t, model.LegalPerson)
require.Empty(t, model.Address)
}

func TestIVMS101DatabaseScanEmptyBytes(t *testing.T) {
type Object struct {
Identity *ivms101.IdentityPayload `json:"identity"`
Person *ivms101.Person `json:"person"`
NaturalPerson *ivms101.NaturalPerson `json:"naturalPerson"`
LegalPerson *ivms101.LegalPerson `json:"legalPerson"`
Address *ivms101.Address `json:"address"`
}

model := &Object{}
mock := &MockRow{
raw: [][]byte{{}, {}, {}, {}, {}},
}

err := mock.Scan(&model.Identity, &model.Person, &model.NaturalPerson, &model.LegalPerson, &model.Address)
require.EqualError(t, err, "unexpected end of JSON input")
}

func TestIVMS101DatabaseScanNullJSON(t *testing.T) {
type Object struct {
Identity *ivms101.IdentityPayload `json:"identity"`
Person *ivms101.Person `json:"person"`
NaturalPerson *ivms101.NaturalPerson `json:"naturalPerson"`
LegalPerson *ivms101.LegalPerson `json:"legalPerson"`
Address *ivms101.Address `json:"address"`
}

model := &Object{}
mock := &MockRow{
raw: [][]byte{{110, 117, 108, 108}, {110, 117, 108, 108}, {110, 117, 108, 108}, {110, 117, 108, 108}, {110, 117, 108, 108}},
}

err := mock.Scan(&model.Identity, &model.Person, &model.NaturalPerson, &model.LegalPerson, &model.Address)
require.NoError(t, err)
require.Empty(t, model.Identity)
require.Empty(t, model.Person)
require.Empty(t, model.NaturalPerson)
require.Empty(t, model.LegalPerson)
require.Empty(t, model.Address)
}

type MockRow struct {
raw [][]byte
}

var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error

func (r *MockRow) Open(paths ...string) error {
r.raw = make([][]byte, 0, len(paths))
for _, path := range paths {
if err := r.open(path); err != nil {
return err
}
}
return nil
}

func (r *MockRow) open(path string) (err error) {
if path == "" {
r.raw = append(r.raw, nil)
return nil
}

var data []byte
if data, err = os.ReadFile(path); err != nil {
return err
}
r.raw = append(r.raw, data)
return nil
}

func (r *MockRow) Scan(dest ...any) error {
if len(dest) != len(r.raw) {
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(r.raw), len(dest))
}

for i, raw := range r.raw {
dst := dest[i]

if raw == nil {
dst = nil
continue
}

if scanner, ok := dst.(sql.Scanner); ok {
if err := scanner.Scan(raw); err != nil {
return err
}
}

dpv := reflect.ValueOf(dst)
if dpv.Kind() != reflect.Pointer {
return errors.New("destination not a pointer")
}
if dpv.IsNil() {
return errNilPtr
}

dv := reflect.Indirect(dpv)
switch dv.Kind() {
case reflect.Pointer:
if dst == nil {
dv.SetZero()
}
dv.Set(reflect.New(dv.Type().Elem()))
dvi := dv.Interface()
if scanner, ok := dvi.(sql.Scanner); ok {
if err := scanner.Scan(raw); err != nil {
return err
}
}
}
}
return nil
}

func loadFixture(path string, obj interface{}) (err error) {
var f *os.File
if f, err = os.Open(path); err != nil {
Expand Down
18 changes: 0 additions & 18 deletions pkg/ivms101/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,6 @@ import (
"strings"
)

// Standard error values for error type checking
var (
ErrNoLegalPersonNameIdentifiers = errors.New("one or more legal person name identifiers is required")
ErrInvalidLegalPersonName = errors.New("legal person name required with max length 100 chars")
ErrInvalidCustomerNumber = errors.New("customer number can be at most 50 chars")
ErrInvalidCountryCode = errors.New("invalid ISO-3166-1 alpha-2 country code")
ErrValidNationalIdentifierLegalPerson = errors.New("a legal person must have a national identifier of type RAID, MISC, LEIX, or TXID")
ErrInvalidLEI = errors.New("national identifier required with max length 35")
ErrCompleteNationalIdentifierCountry = errors.New("a legal person must not have a value for country if identifier type is not LEIX")
ErrCompleteNationalIdentifierAuthorityEmpty = errors.New("a legal person must have a value for registration authority if identifier type is not LEIX")
ErrCompleteNationalIdentifierAuthority = errors.New("a legal person must not have a value for registration authority if identifier type is LEIX")
ErrInvalidDateOfBirth = errors.New("date of birth must be a valid date in YYYY-MM-DD format")
ErrInvalidPlaceOfBirth = errors.New("place of birth required with at most 70 characters")
ErrDateInPast = errors.New("date of birth must be a historic date, prior to current date")
ErrValidAddress = errors.New("address must have at least one address line or street name + building name or number")
ErrInvalidAddressLines = errors.New("an address can contain at most 7 address lines")
)

// Parsing and JSON Serialization Errors
var (
ErrPersonOneOfViolation = errors.New("ivms101: person must be either a legal person or a natural person not both")
Expand Down
46 changes: 44 additions & 2 deletions pkg/ivms101/identity.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package ivms101

import "encoding/json"
import (
"bytes"
"encoding/json"
)

//===========================================================================
// IdentityPayload Methods
Expand Down Expand Up @@ -43,6 +46,11 @@ func (i *IdentityPayload) MarshalJSON() ([]byte, error) {
}

func (i *IdentityPayload) UnmarshalJSON(data []byte) (err error) {
if bytes.Equal(data, nullJSON) {
i = nil
return nil
}

// Perform rekeying operation
if allowRekeying {
if data, err = Rekey(data, serialIdentityPayloadFields); err != nil {
Expand Down Expand Up @@ -98,6 +106,11 @@ func (o *Originator) MarshalJSON() ([]byte, error) {
}

func (o *Originator) UnmarshalJSON(data []byte) (err error) {
if bytes.Equal(data, nullJSON) {
o = nil
return nil
}

// Perform rekeying operation
if allowRekeying {
if data, err = Rekey(data, serialOriginatorFields); err != nil {
Expand Down Expand Up @@ -149,6 +162,11 @@ func (b *Beneficiary) MarshalJSON() ([]byte, error) {
}

func (b *Beneficiary) UnmarshalJSON(data []byte) (err error) {
if bytes.Equal(data, nullJSON) {
b = nil
return nil
}

// Perform rekeying operation
if allowRekeying {
if data, err = Rekey(data, serialBeneficiaryFields); err != nil {
Expand Down Expand Up @@ -190,6 +208,11 @@ func (o *OriginatingVasp) MarshalJSON() ([]byte, error) {
}

func (o *OriginatingVasp) UnmarshalJSON(data []byte) (err error) {
if bytes.Equal(data, nullJSON) {
o = nil
return nil
}

// Perform rekeying operation
if allowRekeying {
if data, err = Rekey(data, serialOriginatorVASPFields); err != nil {
Expand Down Expand Up @@ -230,6 +253,11 @@ func (b *BeneficiaryVasp) MarshalJSON() ([]byte, error) {
}

func (b *BeneficiaryVasp) UnmarshalJSON(data []byte) (err error) {
if bytes.Equal(data, nullJSON) {
b = nil
return nil
}

// Perform rekeying operation
if allowRekeying {
if data, err = Rekey(data, serialBeneficiaryVASPFields); err != nil {
Expand Down Expand Up @@ -275,6 +303,11 @@ func (v *IntermediaryVasp) MarshalJSON() ([]byte, error) {
}

func (v *IntermediaryVasp) UnmarshalJSON(data []byte) (err error) {
if bytes.Equal(data, nullJSON) {
v = nil
return nil
}

// Perform rekeying operation
if allowRekeying {
if data, err = Rekey(data, serialIntermediaryVASPFields); err != nil {
Expand Down Expand Up @@ -316,6 +349,11 @@ func (p *TransferPath) MarshalJSON() ([]byte, error) {
}

func (p *TransferPath) UnmarshalJSON(data []byte) (err error) {
if bytes.Equal(data, nullJSON) {
p = nil
return nil
}

// Perform rekeying operation
if allowRekeying {
if data, err = Rekey(data, serialTransferPathFields); err != nil {
Expand Down Expand Up @@ -359,6 +397,11 @@ func (p *PayloadMetadata) MarshalJSON() ([]byte, error) {
}

func (p *PayloadMetadata) UnmarshalJSON(data []byte) (err error) {
if bytes.Equal(data, nullJSON) {
p = nil
return nil
}

// Perform rekeying operation
if allowRekeying {
if data, err = Rekey(data, serialPayloadMetadataFields); err != nil {
Expand All @@ -374,6 +417,5 @@ func (p *PayloadMetadata) UnmarshalJSON(data []byte) (err error) {

// Populate payload metadata values
p.TransliterationMethod = middle.TransliterationMethod

return nil
}
Loading

0 comments on commit cb6ce3e

Please sign in to comment.