Skip to content

Commit

Permalink
Merge pull request #17 from patrickdappollonio/new-functions
Browse files Browse the repository at this point in the history
Add shuffle, first, last.
  • Loading branch information
patrickdappollonio authored Oct 4, 2022
2 parents f7e01da + 71bb5ad commit 660098e
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 0 deletions.
89 changes: 89 additions & 0 deletions template_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ func getTemplateFunctions(virtualKV map[string]string, strict bool) template.Fun
// Go built-ins
"lowercase": strings.ToLower,
"lower": strings.ToLower,
"tolower": strings.ToLower,
"uppercase": strings.ToUpper,
"upper": strings.ToUpper,
"toupper": strings.ToUpper,
"title": cases.Title,
"sprintf": fmt.Sprintf,
"printf": fmt.Sprintf,
Expand Down Expand Up @@ -75,6 +77,9 @@ func getTemplateFunctions(virtualKV map[string]string, strict bool) template.Fun
"slice": slice,
"after": after,
"skip": after,
"shuffle": shuffle,
"first": first,
"last": last,
}
}

Expand Down Expand Up @@ -326,6 +331,90 @@ func after(index any, seq any) (any, error) {
return seqv.Slice(indexv, seqv.Len()).Interface(), nil
}

func shuffle(seq any) (any, error) {
if seq == nil {
return nil, errors.New("seq must be provided")
}

seqv := reflect.ValueOf(seq)
seqv, isNil := indirectValue(seqv)
if isNil {
return nil, errors.New("can't iterate over a nil value")
}

if seqv.Len() == 0 {
return nil, errors.New("can't shuffle an empty sequence")
}

switch seqv.Kind() {
case reflect.Array, reflect.Slice, reflect.String:
// skip
default:
return nil, errors.New("can't iterate over " + reflect.ValueOf(seq).Type().String())
}

shuffled := reflect.MakeSlice(reflect.TypeOf(seq), seqv.Len(), seqv.Len())

rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
randomIndices := rnd.Perm(seqv.Len())

for index, value := range randomIndices {
shuffled.Index(value).Set(seqv.Index(index))
}

return shuffled.Interface(), nil
}

func first(seq any) (any, error) {
if seq == nil {
return nil, errors.New("seq must be provided")
}

seqv := reflect.ValueOf(seq)
seqv, isNil := indirectValue(seqv)
if isNil {
return nil, errors.New("can't iterate over a nil value")
}

switch seqv.Kind() {
case reflect.Array, reflect.Slice, reflect.String:
// okay
default:
return nil, errors.New("can't iterate over " + reflect.ValueOf(seq).Type().String())
}

if seqv.Len() == 0 {
return nil, errors.New("can't get first item of an empty sequence")
}

return seqv.Index(0).Interface(), nil
}

func last(seq any) (any, error) {
if seq == nil {
return nil, errors.New("seq must be provided")
}

seqv := reflect.ValueOf(seq)
seqv, isNil := indirectValue(seqv)
if isNil {
return nil, errors.New("can't iterate over a nil value")
}

switch seqv.Kind() {
case reflect.Array, reflect.Slice, reflect.String:
// okay
default:
return nil, errors.New("can't iterate over " + reflect.ValueOf(seq).Type().String())
}

if seqv.Len() == 0 {
return nil, errors.New("can't get last item of an empty sequence")
}

return seqv.Index(seqv.Len() - 1).Interface(), nil
}

func toInt(v interface{}) (int, bool) {
switch v := v.(type) {
case int:
Expand Down
120 changes: 120 additions & 0 deletions template_functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,123 @@ func Test_after(t *testing.T) {
})
}
}

func Test_shuffle(t *testing.T) {
tests := []struct {
name string
seq any
wantErr bool
}{
{
name: "shuffle 1 2 3",
seq: []int{1, 2, 3},
},
{
name: "nil",
seq: nil,
wantErr: true,
},
{
name: "empty",
seq: []int{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := shuffle(tt.seq)
if (err != nil) != tt.wantErr {
t.Errorf("shuffle() error = %v, wantErr %v", err, tt.wantErr)
return
}

var l1, l2 int

if got != nil {
l1 = reflect.ValueOf(got).Len()
}

if tt.seq != nil {
l2 = reflect.ValueOf(tt.seq).Len()
}

if !tt.wantErr && l1 != l2 {
t.Errorf("shuffle() got length = %d (original: %d)", l1, l2)
}
})
}
}

func Test_first(t *testing.T) {
tests := []struct {
name string
seq any
want any
wantErr bool
}{
{
name: "first 1 2 3",
seq: []int{1, 2, 3},
want: 1,
},
{
name: "nil",
seq: nil,
wantErr: true,
},
{
name: "empty",
seq: []int{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := first(tt.seq)
if (err != nil) != tt.wantErr {
t.Errorf("first() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("first() = %v, want %v", got, tt.want)
}
})
}
}

func Test_last(t *testing.T) {
tests := []struct {
name string
seq any
want any
wantErr bool
}{
{
name: "last 1 2 3",
seq: []int{1, 2, 3},
want: 3,
},
{
name: "nil",
seq: nil,
wantErr: true,
},
{
name: "empty",
seq: []int{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := last(tt.seq)
if (err != nil) != tt.wantErr {
t.Errorf("last() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("last() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit 660098e

Please sign in to comment.