diff --git a/std/msgfilter_test.go b/std/msgfilter_test.go index 29925e1f..347d4ece 100644 --- a/std/msgfilter_test.go +++ b/std/msgfilter_test.go @@ -4,6 +4,7 @@ package std import ( "context" + "fmt" "testing" "github.com/andersfylling/disgord" @@ -46,55 +47,75 @@ func TestNewMsgFilter(t *testing.T) { func TestMsgFilter_NotByBot(t *testing.T) { var botID disgord.Snowflake = 123 - filter, _ := newMsgFilter(context.Background(), &clientRESTMock{id: botID}) - evt := &disgord.MessageCreate{ - Message: &disgord.Message{ - Author: &disgord.User{Bot: true}, - }, + + messageFromBot := &disgord.Message{ + Author: &disgord.User{Bot: true}, } - result := filter.NotByBot(evt) - if result != nil { - t.Error("expected a match") + messageNotFromBot := &disgord.Message{ + Author: &disgord.User{Bot: false}, } - evt.Message.Author.Bot = false - result = filter.NotByBot(evt) - if result == nil { - t.Error("expected pass-through") + testCases := []struct { + name string + evt interface{} + shouldPassThrough bool + }{ + {"MessageCreate_FromBot", &disgord.MessageCreate{Message: messageFromBot}, false}, + {"MessageUpdate_FromBot", &disgord.MessageUpdate{Message: messageFromBot}, false}, + {"MessageCreate_NotBot", &disgord.MessageCreate{Message: messageNotFromBot}, true}, + {"MessageUpdate_NotBot", &disgord.MessageUpdate{Message: messageNotFromBot}, true}, } - evt.Message.Author = nil - result = filter.NotByBot(evt) - if result == nil { - t.Error("expected pass-through") + filter, _ := newMsgFilter(context.Background(), &clientRESTMock{id: botID}) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := filter.NotByBot(tc.evt) + if tc.shouldPassThrough && result == nil { + t.Error("expected to passthrough") + } + if !tc.shouldPassThrough && result != nil { + t.Error("expected a filter match") + } + }) } } func TestMsgFilter_IsByBot(t *testing.T) { var botID disgord.Snowflake = 123 - filter, _ := newMsgFilter(context.Background(), &clientRESTMock{id: botID}) - evt := &disgord.MessageCreate{ - Message: &disgord.Message{ - Author: &disgord.User{Bot: false}, - }, + + messageFromBot := &disgord.Message{ + Author: &disgord.User{Bot: true}, } - result := filter.IsByBot(evt) - if result != nil { - t.Error("expected a match") + messageNotFromBot := &disgord.Message{ + Author: &disgord.User{Bot: false}, } - evt.Message.Author.Bot = true - result = filter.IsByBot(evt) - if result == nil { - t.Error("expected pass-through") + testCases := []struct { + name string + evt interface{} + shouldPassThrough bool + }{ + {"MessageCreate_FromBot", &disgord.MessageCreate{Message: messageFromBot}, true}, + {"MessageUpdate_FromBot", &disgord.MessageUpdate{Message: messageFromBot}, true}, + {"MessageCreate_NotBot", &disgord.MessageCreate{Message: messageNotFromBot}, false}, + {"MessageUpdate_NotBot", &disgord.MessageUpdate{Message: messageNotFromBot}, false}, } - evt.Message.Author = nil - result = filter.IsByBot(evt) - if result == nil { - t.Error("expected pass-through") + filter, _ := newMsgFilter(context.Background(), &clientRESTMock{id: botID}) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := filter.IsByBot(tc.evt) + if tc.shouldPassThrough && result == nil { + t.Error("expected to passthrough") + } + if !tc.shouldPassThrough && result != nil { + t.Error("expected a filter match") + } + }) } } @@ -180,49 +201,137 @@ func TestMsgFilter_IsByWebhook(t *testing.T) { func TestMsgFilter_ContainsBotMention(t *testing.T) { var botID disgord.Snowflake = 123 - filter, _ := newMsgFilter(context.Background(), &clientRESTMock{id: botID}) - var evt interface{} - e := &disgord.MessageCreate{ - Message: &disgord.Message{Content: "<@" + botID.String() + "> hello"}, + var wrongBotID disgord.Snowflake = 126 + + messageCreate := func(content string) interface{} { + return &disgord.MessageCreate{ + Message: &disgord.Message{Content: content}, + } } - evt = e - result := filter.ContainsBotMention(evt) - if result == nil { - t.Error("expected to find a match") + messageUpdate := func(content string) interface{} { + return &disgord.MessageUpdate{ + Message: &disgord.Message{Content: content}, + } } - e.Message.Content = "diff prefix " + e.Message.Content - result = filter.ContainsBotMention(evt) - if result == nil { - t.Error("expected to find a match") + testCases := []struct { + name string + evt interface{} + shouldPassThrough bool + }{ + { + "MessageCreate_BotHello", + messageCreate(fmt.Sprintf("<@%s> hello", botID.String())), + true, + }, + { + "MessageUpdate_BotHello", + messageUpdate(fmt.Sprintf("<@%s> hello", botID.String())), + true, + }, + { + "MessageCreate_WrongBotHello", + messageCreate(fmt.Sprintf("<@%s> hello", wrongBotID.String())), + false, + }, + { + "MessageUpdate_WrongBotHello", + messageUpdate(fmt.Sprintf("<@%s> hello", wrongBotID.String())), + false, + }, + { + "MessageCreate_BotHellWithPrefix", + messageCreate(fmt.Sprintf("diff prefix <@%s> hello", botID.String())), + true, + }, + { + "MessageUpdate_BotHelloWithPrefix", + messageUpdate(fmt.Sprintf("diff prefix <@%s> hello", botID.String())), + true, + }, } - filter.botID = botID + 3 - result = filter.ContainsBotMention(evt) - if result != nil { - t.Error("did not expect a match") + filter, _ := newMsgFilter(context.Background(), &clientRESTMock{id: botID}) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := filter.ContainsBotMention(tc.evt) + if tc.shouldPassThrough && result == nil { + t.Error("expected to passthrough") + } + if !tc.shouldPassThrough && result != nil { + t.Error("expected a filter match") + } + }) } } func TestMsgFilter_HasBotMentionPrefix(t *testing.T) { var botID disgord.Snowflake = 123 - filter, _ := newMsgFilter(context.Background(), &clientRESTMock{id: botID}) - var evt interface{} - e := &disgord.MessageCreate{ - Message: &disgord.Message{Content: "<@" + botID.String() + "> hello"}, + var wrongBotID disgord.Snowflake = 126 + + messageCreate := func(content string) interface{} { + return &disgord.MessageCreate{ + Message: &disgord.Message{Content: content}, + } } - evt = e - result := filter.HasBotMentionPrefix(evt) - if result == nil { - t.Error("expected to find a match") + messageUpdate := func(content string) interface{} { + return &disgord.MessageUpdate{ + Message: &disgord.Message{Content: content}, + } } - e.Message.Content = "diff prefix " + e.Message.Content - result = filter.HasBotMentionPrefix(evt) - if result != nil { - t.Error("did not expect a match") + testCases := []struct { + name string + evt interface{} + shouldPassThrough bool + }{ + { + "MessageCreate_BotHello", + messageCreate(fmt.Sprintf("<@%s> hello", botID.String())), + true, + }, + { + "MessageUpdate_BotHello", + messageUpdate(fmt.Sprintf("<@%s> hello", botID.String())), + true, + }, + { + "MessageCreate_WrongBotHello", + messageCreate(fmt.Sprintf("<@%s> hello", wrongBotID.String())), + false, + }, + { + "MessageUpdate_WrongBotHello", + messageUpdate(fmt.Sprintf("<@%s> hello", wrongBotID.String())), + false, + }, + { + "MessageCreate_BotHellWithDiffPrefix", + messageCreate(fmt.Sprintf("diff prefix <@%s> hello", botID.String())), + false, + }, + { + "MessageUpdate_BotHelloWithDiffPrefix", + messageUpdate(fmt.Sprintf("diff prefix <@%s> hello", botID.String())), + false, + }, + } + + filter, _ := newMsgFilter(context.Background(), &clientRESTMock{id: botID}) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := filter.HasBotMentionPrefix(tc.evt) + if tc.shouldPassThrough && result == nil { + t.Error("expected to passthrough") + } + if !tc.shouldPassThrough && result != nil { + t.Error("expected a filter match") + } + }) } } @@ -239,44 +348,80 @@ func TestMsgFilter_SetPrefix(t *testing.T) { } func TestMsgFilter_HasPrefix(t *testing.T) { + messageCreate := func(content string) interface{} { + return &disgord.MessageCreate{ + Message: &disgord.Message{Content: content}, + } + } + + messageUpdate := func(content string) interface{} { + return &disgord.MessageUpdate{ + Message: &disgord.Message{Content: content}, + } + } + prefix := "!!" + testCases := []struct { + name string + evt interface{} + shouldPassThrough bool + }{ + {"MessageCreate_CorrectPrefix", messageCreate(prefix + "hello"), true}, + {"MessageUpdate_CorrectPrefix", messageUpdate(prefix + "hello"), true}, + {"MessageCreate_WrongPrefix", messageCreate("diff prefix " + prefix + "hello"), false}, + {"MessageUpdate_WrongPrefix", messageUpdate("diff prefix " + prefix + "hello"), false}, + } + filter, _ := newMsgFilter(context.Background(), &clientRESTMock{}) filter.SetPrefix(prefix) - var evt interface{} - e := &disgord.MessageCreate{ - Message: &disgord.Message{Content: prefix + "hello"}, + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := filter.HasPrefix(tc.evt) + if tc.shouldPassThrough && result == nil { + t.Error("expected to passthrough") + } + if !tc.shouldPassThrough && result != nil { + t.Error("expected a filter match") + } + }) } - evt = e +} - result := filter.HasPrefix(evt) - if result == nil { - t.Error("expected to find a match") +func TestMsgFilter_StripPrefix(t *testing.T) { + messageCreate := func(content string) interface{} { + return &disgord.MessageCreate{ + Message: &disgord.Message{Content: content}, + } } - e.Message.Content = "diff prefix " + e.Message.Content - result = filter.HasBotMentionPrefix(evt) - if result != nil { - t.Error("did not expect a match") + messageUpdate := func(content string) interface{} { + return &disgord.MessageUpdate{ + Message: &disgord.Message{Content: content}, + } } -} -func TestMsgFilter_StripPrefix(t *testing.T) { prefix := "!!" + testCases := []struct { + name string + evt interface{} + }{ + {"MessageCreate", messageCreate(prefix + "hello")}, + {"MessageUpdate", messageUpdate(prefix + "hello")}, + } + filter, _ := newMsgFilter(context.Background(), &clientRESTMock{}) filter.SetPrefix(prefix) - var evt interface{} - e := &disgord.MessageCreate{ - Message: &disgord.Message{Content: prefix + "hello"}, - } - evt = e - - result := filter.StripPrefix(evt) - if result == nil { - t.Error("expected prefix stripping to work") - } - if filter.HasPrefix(evt) != nil { - t.Error("did not strip prefix off message") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := filter.StripPrefix(tc.evt) + if result == nil { + t.Error("expected to passthrough") + } + if filter.HasPrefix(tc.evt) != nil { + t.Error("Did not strip prefix off message") + } + }) } }