diff --git a/courier/courier.go b/courier/courier.go index 1598344c3d2a..b42a580fc061 100644 --- a/courier/courier.go +++ b/courier/courier.go @@ -65,17 +65,12 @@ func NewCourier(ctx context.Context, deps Dependencies) (Courier, error) { if err != nil { return nil, err } - - expBackoff := backoff.NewExponentialBackOff() - // never stop retrying - expBackoff.MaxElapsedTime = 0 - return &courier{ smsClient: newSMS(ctx, deps), smtpClient: smtp, httpClient: newHTTP(ctx, deps), deps: deps, - backoff: expBackoff, + backoff: backoff.NewExponentialBackOff(), }, nil } @@ -84,29 +79,36 @@ func (c *courier) FailOnDispatchError() { } func (c *courier) Work(ctx context.Context) error { - wait := c.deps.CourierConfig().CourierWorkerPullWait(ctx) - for { - select { - case <-ctx.Done(): - if errors.Is(ctx.Err(), context.Canceled) { - return nil - } - return ctx.Err() - case <-time.After(wait): - if err := backoff.Retry(func() error { - if err := c.DispatchQueue(ctx); err != nil { - return err - } - // when we succeed, we want to reset the backoff - c.backoff.Reset() - return nil - }, c.backoff); err != nil { - return err - } + errChan := make(chan error) + defer close(errChan) + + go c.watchMessages(ctx, errChan) + + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + return nil } + return ctx.Err() + case err := <-errChan: + return err } } func (c *courier) UseBackoff(b backoff.BackOff) { c.backoff = b } + +func (c *courier) watchMessages(ctx context.Context, errChan chan error) { + wait := c.deps.CourierConfig().CourierWorkerPullWait(ctx) + c.backoff.Reset() + for { + if err := backoff.Retry(func() error { + return c.DispatchQueue(ctx) + }, c.backoff); err != nil { + errChan <- err + return + } + time.Sleep(wait) + } +} diff --git a/courier/courier_dispatcher.go b/courier/courier_dispatcher.go index 33d909e59fb2..8470c024fca4 100644 --- a/courier/courier_dispatcher.go +++ b/courier/courier_dispatcher.go @@ -58,10 +58,13 @@ func (c *courier) DispatchQueue(ctx context.Context) error { messages, err := c.deps.CourierPersister().NextMessages(ctx, uint8(pullCount)) if err != nil { + if errors.Is(err, ErrQueueEmpty) { + return nil + } return err } - for _, msg := range messages { + for k, msg := range messages { if msg.SendCount > maxRetries { if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusAbandoned); err != nil { c.deps.Logger(). @@ -77,33 +80,41 @@ func (c *courier) DispatchQueue(ctx context.Context) error { WithField("message_id", msg.ID). WithField("message_nid", msg.NID). Warnf(`Message was abandoned because it did not deliver after %d attempts`, msg.SendCount) - continue - } - - if err := c.DispatchMessage(ctx, msg); err != nil { + } else if err := c.DispatchMessage(ctx, msg); err != nil { if err := c.deps.CourierPersister().RecordDispatch(ctx, msg.ID, CourierMessageDispatchStatusFailed, err); err != nil { c.deps.Logger(). WithError(err). WithField("message_id", msg.ID). WithField("message_nid", msg.NID). Error(`Unable to record failure log entry.`) - return err + if c.failOnDispatchError { + return err + } + } + + for _, replace := range messages[k:] { + if err := c.deps.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", replace.ID). + WithField("message_nid", replace.NID). + Error(`Unable to reset the failed message's status to "queued".`) + if c.failOnDispatchError { + return err + } + } } if c.failOnDispatchError { return err } - // an error happened, but we want to ignore it - continue - } - - if err := c.deps.CourierPersister().RecordDispatch(ctx, msg.ID, CourierMessageDispatchStatusSuccess, nil); err != nil { + } else if err := c.deps.CourierPersister().RecordDispatch(ctx, msg.ID, CourierMessageDispatchStatusSuccess, nil); err != nil { c.deps.Logger(). WithError(err). WithField("message_id", msg.ID). WithField("message_nid", msg.NID). Error(`Unable to record success log entry.`) - return err + // continue with execution, as the message was successfully dispatched } } diff --git a/courier/courier_dispatcher_test.go b/courier/courier_dispatcher_test.go index b621072c9997..528badf2de02 100644 --- a/courier/courier_dispatcher_test.go +++ b/courier/courier_dispatcher_test.go @@ -44,16 +44,15 @@ func TestDispatchMessageWithInvalidSMTP(t *testing.T) { t.Run("case=failed sending", func(t *testing.T) { id := queueNewMessage(t, ctx, c, reg) - messages, err := reg.CourierPersister().NextMessages(ctx, 10) + message, err := reg.CourierPersister().LatestQueuedMessage(ctx) require.NoError(t, err) - require.Len(t, messages, 1) - require.Equal(t, id, messages[0].ID) + require.Equal(t, id, message.ID) - err = c.DispatchMessage(ctx, messages[0]) + err = c.DispatchMessage(ctx, *message) // sending the email fails, because there is no SMTP server at foo.url require.Error(t, err) - messages, err = reg.CourierPersister().NextMessages(ctx, 10) + messages, err := reg.CourierPersister().NextMessages(ctx, 10) require.NoError(t, err) require.Len(t, messages, 1) }) diff --git a/courier/handler_test.go b/courier/handler_test.go index 74c47bae5c53..b8920a5bc5f9 100644 --- a/courier/handler_test.go +++ b/courier/handler_test.go @@ -96,7 +96,7 @@ func TestHandler(t *testing.T) { t.Run("case=list messages", func(t *testing.T) { // Arrange test data const msgCount = 10 // total message count - const sentCount = 5 // how many messages' status should be equal to `processing` + const procCount = 5 // how many messages' status should be equal to `processing` const rcptOryCount = 2 // how many messages' recipient should be equal to `noreply@ory.sh` messages := make([]courier.Message, msgCount) @@ -109,8 +109,8 @@ func TestHandler(t *testing.T) { } require.NoError(t, reg.CourierPersister().AddMessage(context.Background(), &messages[i])) } - for i := 0; i < sentCount; i++ { - require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[i].ID, courier.MessageStatusSent)) + for i := 0; i < procCount; i++ { + require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[i].ID, courier.MessageStatusProcessing)) } t.Run("paging", func(t *testing.T) { @@ -146,7 +146,7 @@ func TestHandler(t *testing.T) { for _, tc := range tss { t.Run("endpoint="+tc.name, func(t *testing.T) { parsed := getList(t, tc.name, qs) - assert.Len(t, parsed.Array(), msgCount-sentCount) + assert.Len(t, parsed.Array(), msgCount-procCount) for _, item := range parsed.Array() { assert.Equal(t, "queued", item.Get("status").String()) @@ -154,16 +154,16 @@ func TestHandler(t *testing.T) { }) } }) - t.Run("case=should return all sent messages", func(t *testing.T) { - qs := fmt.Sprintf(`?page_token=%s&page_size=250&status=sent`, defaultPageToken) + t.Run("case=should return all processing messages", func(t *testing.T) { + qs := fmt.Sprintf(`?page_token=%s&page_size=250&status=processing`, defaultPageToken) for _, tc := range tss { t.Run("endpoint="+tc.name, func(t *testing.T) { parsed := getList(t, tc.name, qs) - assert.Len(t, parsed.Array(), sentCount) + assert.Len(t, parsed.Array(), procCount) for _, item := range parsed.Array() { - assert.Equal(t, "sent", item.Get("status").String()) + assert.Equal(t, "processing", item.Get("status").String()) } }) } diff --git a/courier/persistence.go b/courier/persistence.go index 6bb2d4fa6230..4e5834f7faca 100644 --- a/courier/persistence.go +++ b/courier/persistence.go @@ -5,9 +5,9 @@ package courier import ( "context" - "errors" "github.com/gofrs/uuid" + "github.com/pkg/errors" "github.com/ory/x/pagination/keysetpagination" ) @@ -22,6 +22,8 @@ type ( SetMessageStatus(context.Context, uuid.UUID, MessageStatus) error + LatestQueuedMessage(ctx context.Context) (*Message, error) + IncrementMessageSendCount(context.Context, uuid.UUID) error // ListMessages lists all messages in the store given the page, itemsPerPage, status and recipient. diff --git a/courier/test/persistence.go b/courier/test/persistence.go index 2a6e33650c84..fad80efe1742 100644 --- a/courier/test/persistence.go +++ b/courier/test/persistence.go @@ -6,11 +6,10 @@ package test import ( "context" "errors" + "fmt" "testing" "time" - "github.com/ory/x/pointerx" - "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/tidwall/gjson" @@ -39,54 +38,85 @@ func TestPersister(ctx context.Context, newNetworkUnlessExisting NetworkWrapper, t.Run("case=no messages in queue", func(t *testing.T) { m, err := p.NextMessages(ctx, 10) - require.NoError(t, err) + require.ErrorIs(t, err, courier.ErrQueueEmpty) assert.Len(t, m, 0) + + _, err = p.LatestQueuedMessage(ctx) + require.ErrorIs(t, err, courier.ErrQueueEmpty) }) messages := make([]courier.Message, 5) t.Run("case=add messages to the queue", func(t *testing.T) { - start := time.Now().UTC() for k := range messages { - pop.SetNowFunc(func() time.Time { - return start.Add(time.Duration(k) * time.Second) - }) require.NoError(t, faker.FakeData(&messages[k])) require.NoError(t, p.AddMessage(ctx, &messages[k])) + time.Sleep(time.Second) // wait a bit so that the timestamp ordering works in MySQL. } - pop.SetNowFunc(time.Now) }) - t.Run("case=get queued messages", func(t *testing.T) { - actual, err := p.NextMessages(ctx, 10) + t.Run("case=latest message in queue", func(t *testing.T) { + expected, err := p.LatestQueuedMessage(ctx) require.NoError(t, err) - assert.ElementsMatch(t, messages, actual) + + actual := messages[len(messages)-1] + assert.Equal(t, expected.ID, actual.ID) + assert.Equal(t, expected.Subject, actual.Subject) }) - t.Run("case=setting message status", func(t *testing.T) { - require.NoError(t, p.SetMessageStatus(ctx, messages[0].ID, courier.MessageStatusSent)) - require.NoError(t, p.SetMessageStatus(ctx, messages[1].ID, courier.MessageStatusAbandoned)) - require.NoError(t, p.SetMessageStatus(ctx, messages[2].ID, courier.MessageStatusQueued)) + t.Run("case=pull messages from the queue", func(t *testing.T) { + for k, expected := range messages { + expected.Status = courier.MessageStatusProcessing + t.Run(fmt.Sprintf("message=%d", k), func(t *testing.T) { + messages, err := p.NextMessages(ctx, 1) + require.NoError(t, err) + require.Len(t, messages, 1) + + actual := messages[0] + assert.Equal(t, expected.ID, actual.ID) + assert.Equal(t, expected.Subject, actual.Subject) + assert.Equal(t, expected.Body, actual.Body) + assert.Equal(t, expected.Status, actual.Status) + assert.Equal(t, expected.Type, actual.Type) + assert.Equal(t, expected.Recipient, actual.Recipient) + }) + } - ms, err := p.NextMessages(ctx, 10) - require.NoError(t, err) - assert.ElementsMatch(t, messages[2:], ms) + _, err := p.NextMessages(ctx, 10) + require.ErrorIs(t, err, courier.ErrQueueEmpty) + }) + t.Run("case=setting message status", func(t *testing.T) { require.NoError(t, p.SetMessageStatus(ctx, messages[0].ID, courier.MessageStatusQueued)) - require.NoError(t, p.SetMessageStatus(ctx, messages[1].ID, courier.MessageStatusQueued)) + ms, err := p.NextMessages(ctx, 1) + require.NoError(t, err) + require.Len(t, ms, 1) + assert.Equal(t, messages[0].ID, ms[0].ID) + + require.NoError(t, p.SetMessageStatus(ctx, messages[0].ID, courier.MessageStatusSent)) + _, err = p.NextMessages(ctx, 1) + require.ErrorIs(t, err, courier.ErrQueueEmpty) + + require.NoError(t, p.SetMessageStatus(ctx, messages[0].ID, courier.MessageStatusAbandoned)) + _, err = p.NextMessages(ctx, 1) + require.ErrorIs(t, err, courier.ErrQueueEmpty) }) t.Run("case=incrementing send count", func(t *testing.T) { originalSendCount := messages[0].SendCount + require.NoError(t, p.SetMessageStatus(ctx, messages[0].ID, courier.MessageStatusQueued)) require.NoError(t, p.IncrementMessageSendCount(ctx, messages[0].ID)) - message, err := p.FetchMessage(ctx, messages[0].ID) + ms, err := p.NextMessages(ctx, 1) require.NoError(t, err) - assert.Equal(t, originalSendCount+1, message.SendCount) + require.Len(t, ms, 1) + assert.Equal(t, messages[0].ID, ms[0].ID) + assert.Equal(t, originalSendCount+1, ms[0].SendCount) }) t.Run("case=list messages", func(t *testing.T) { + status := courier.MessageStatusProcessing filter := courier.ListCourierMessagesParameters{ - Status: pointerx.Ptr(courier.MessageStatusQueued), + Status: &status, } ms, total, _, err := p.ListMessages(ctx, filter, []keysetpagination.Option{}) @@ -111,24 +141,24 @@ func TestPersister(ctx context.Context, newNetworkUnlessExisting NetworkWrapper, nid2, p2 := newNetwork(t, ctx) const timeFormat = "2006-01-02 15:04:05.99999" msg1 := courier.Message{ - ID: uuid.Must(uuid.FromString("10000000-0000-0000-0000-000000000000")), + ID: uuid.FromStringOrNil("10000000-0000-0000-0000-000000000000"), NID: nid1, - Status: courier.MessageStatusQueued, + Status: courier.MessageStatusProcessing, } err = p1.GetConnection(ctx).Create(&msg1) require.NoError(t, err) msg2 := courier.Message{ - ID: uuid.Must(uuid.FromString("20000000-0000-0000-0000-000000000000")), + ID: uuid.FromStringOrNil("20000000-0000-0000-0000-000000000000"), NID: nid1, - Status: courier.MessageStatusQueued, + Status: courier.MessageStatusProcessing, } err = p1.GetConnection(ctx).Create(&msg2) require.NoError(t, err) msg3 := courier.Message{ - ID: uuid.Must(uuid.FromString("30000000-0000-0000-0000-000000000000")), + ID: uuid.FromStringOrNil("30000000-0000-0000-0000-000000000000"), NID: nid2, - Status: courier.MessageStatusQueued, + Status: courier.MessageStatusProcessing, } err = p2.GetConnection(ctx).Create(&msg3) require.NoError(t, err) @@ -163,8 +193,16 @@ func TestPersister(ctx context.Context, newNetworkUnlessExisting NetworkWrapper, assert.EqualValues(t, nid, expected.NID) assert.EqualValues(t, nid, p.NetworkID(ctx)) - actual, err := p.FetchMessage(ctx, expected.ID) + actual, err := p.LatestQueuedMessage(ctx) require.NoError(t, err) + assert.EqualValues(t, expected.ID, actual.ID) + assert.EqualValues(t, nid, actual.NID) + + actuals, err := p.NextMessages(ctx, 255) + require.NoError(t, err) + + actual = &actuals[0] + assert.EqualValues(t, expected.ID, actual.ID) assert.EqualValues(t, nid, actual.NID) }) @@ -178,25 +216,32 @@ func TestPersister(ctx context.Context, newNetworkUnlessExisting NetworkWrapper, assert.EqualValues(t, nid, expected.NID) assert.EqualValues(t, nid, p.NetworkID(ctx)) - actual, err := p.FetchMessage(ctx, id) + actual, err := p.LatestQueuedMessage(ctx) + require.NoError(t, err) + assert.EqualValues(t, id, actual.ID) + assert.EqualValues(t, nid, actual.NID) + + actuals, err := p.NextMessages(ctx, 255) require.NoError(t, err) + + actual = &actuals[0] + assert.EqualValues(t, id, actual.ID) assert.EqualValues(t, nid, actual.NID) }) t.Run("can not get on another network", func(t *testing.T) { _, p := newNetwork(t, ctx) - actual, err := p.NextMessages(ctx, 255) - require.NoError(t, err) - assert.Len(t, actual, 0) + _, err := p.LatestQueuedMessage(ctx) + require.ErrorIs(t, err, courier.ErrQueueEmpty) - _, err = p.FetchMessage(ctx, id) - assert.ErrorIs(t, err, sqlcon.ErrNoRows) + _, err = p.NextMessages(ctx, 255) + require.ErrorIs(t, err, courier.ErrQueueEmpty) }) t.Run("can not update on another network", func(t *testing.T) { _, p := newNetwork(t, ctx) - err := p.SetMessageStatus(ctx, id, courier.MessageStatusAbandoned) + err := p.SetMessageStatus(ctx, id, courier.MessageStatusProcessing) require.ErrorIs(t, err, sqlcon.ErrNoRows) }) diff --git a/persistence/sql/persister_courier.go b/persistence/sql/persister_courier.go index c706f8d547b1..437d9132e1d0 100644 --- a/persistence/sql/persister_courier.go +++ b/persistence/sql/persister_courier.go @@ -5,8 +5,10 @@ package sql import ( "context" + "database/sql" "encoding/json" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/pkg/errors" @@ -16,6 +18,7 @@ import ( "github.com/ory/x/uuidx" "github.com/ory/kratos/courier" + "github.com/ory/kratos/persistence/sql/update" ) var _ courier.Persister = new(Persister) @@ -67,18 +70,66 @@ func (p *Persister) NextMessages(ctx context.Context, limit uint8) (messages []c ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.NextMessages") defer span.End() - if err := p.Connection(ctx). + if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + var m []courier.Message + if err := tx. + Where("nid = ? AND status = ?", + p.NetworkID(ctx), + courier.MessageStatusQueued, + ). + Order("created_at ASC"). + Limit(int(limit)). + All(&m); err != nil { + return err + } + + if len(m) == 0 { + return sql.ErrNoRows + } + + for i := range m { + message := &m[i] + message.Status = courier.MessageStatusProcessing + if err := update.Generic(ctx, p.GetConnection(ctx), p.r.Tracer(ctx).Tracer(), message, "status"); err != nil { + return err + } + } + + messages = m + return nil + }); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, errors.WithStack(courier.ErrQueueEmpty) + } + return nil, sqlcon.HandleError(err) + } + + if len(messages) == 0 { + return nil, errors.WithStack(courier.ErrQueueEmpty) + } + + return messages, nil +} + +func (p *Persister) LatestQueuedMessage(ctx context.Context) (*courier.Message, error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.LatestQueuedMessage") + defer span.End() + + var m courier.Message + if err := p.GetConnection(ctx). Where("nid = ? AND status = ?", p.NetworkID(ctx), courier.MessageStatusQueued, ). - Order("created_at ASC"). - Limit(int(limit)). - All(&messages); err != nil { + Order("created_at DESC"). + First(&m); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errors.WithStack(courier.ErrQueueEmpty) + } return nil, sqlcon.HandleError(err) } - return messages, nil + return &m, nil } func (p *Persister) SetMessageStatus(ctx context.Context, id uuid.UUID, ms courier.MessageStatus) error { diff --git a/selfservice/hook/verification_test.go b/selfservice/hook/verification_test.go index 8a60a2dfeb9d..652df7bec61c 100644 --- a/selfservice/hook/verification_test.go +++ b/selfservice/hook/verification_test.go @@ -11,7 +11,6 @@ import ( "time" "github.com/ory/kratos/courier" - "github.com/ory/kratos/internal/testhelpers" "github.com/stretchr/testify/assert" @@ -100,9 +99,6 @@ func TestVerifier(t *testing.T) { require.NoError(t, err) require.Len(t, messages, 2) - require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[0].ID, courier.MessageStatusSent)) - require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[1].ID, courier.MessageStatusSent)) - recipients := make([]string, len(messages)) for k, m := range messages { recipients[k] = m.Recipient @@ -135,7 +131,7 @@ func TestVerifier(t *testing.T) { require.NoError(t, err) messages, err = reg.CourierPersister().NextMessages(context.Background(), 12) - require.NoError(t, err) + require.EqualError(t, err, courier.ErrQueueEmpty.Error()) assert.Len(t, messages, 0) }) } diff --git a/selfservice/strategy/code/code_sender_test.go b/selfservice/strategy/code/code_sender_test.go index 40ad1e03967e..e5ba75826eb5 100644 --- a/selfservice/strategy/code/code_sender_test.go +++ b/selfservice/strategy/code/code_sender_test.go @@ -65,9 +65,6 @@ func TestSender(t *testing.T) { require.NoError(t, err) require.Len(t, messages, 2) - require.NoError(t, reg.CourierPersister().SetMessageStatus(ctx, messages[0].ID, courier.MessageStatusSent)) - require.NoError(t, reg.CourierPersister().SetMessageStatus(ctx, messages[1].ID, courier.MessageStatusSent)) - assert.EqualValues(t, "tracked@ory.sh", messages[0].Recipient) assert.Contains(t, messages[0].Subject, "Recover access to your account") @@ -93,9 +90,6 @@ func TestSender(t *testing.T) { require.NoError(t, err) require.Len(t, messages, 2) - require.NoError(t, reg.CourierPersister().SetMessageStatus(ctx, messages[0].ID, courier.MessageStatusSent)) - require.NoError(t, reg.CourierPersister().SetMessageStatus(ctx, messages[1].ID, courier.MessageStatusSent)) - assert.EqualValues(t, "tracked@ory.sh", messages[0].Recipient) assert.Equal(t, messages[0].Subject, subject+" valid") assert.Contains(t, messages[0].Body, body) @@ -127,9 +121,6 @@ func TestSender(t *testing.T) { require.NoError(t, err) require.Len(t, messages, 2) - require.NoError(t, reg.CourierPersister().SetMessageStatus(ctx, messages[0].ID, courier.MessageStatusSent)) - require.NoError(t, reg.CourierPersister().SetMessageStatus(ctx, messages[1].ID, courier.MessageStatusSent)) - assert.EqualValues(t, "tracked@ory.sh", messages[0].Recipient) assert.Contains(t, messages[0].Subject, "Please verify your email address") @@ -155,9 +146,6 @@ func TestSender(t *testing.T) { require.NoError(t, err) require.Len(t, messages, 2) - require.NoError(t, reg.CourierPersister().SetMessageStatus(ctx, messages[0].ID, courier.MessageStatusSent)) - require.NoError(t, reg.CourierPersister().SetMessageStatus(ctx, messages[1].ID, courier.MessageStatusSent)) - assert.EqualValues(t, "tracked@ory.sh", messages[0].Recipient) assert.Equal(t, messages[0].Subject, subject+" valid") assert.Contains(t, messages[0].Body, body) @@ -218,7 +206,7 @@ func TestSender(t *testing.T) { messages, err := reg.CourierPersister().NextMessages(ctx, 0) - require.NoError(t, err) + require.ErrorIs(t, err, courier.ErrQueueEmpty) require.Len(t, messages, 0) }) } diff --git a/selfservice/strategy/link/sender.go b/selfservice/strategy/link/sender.go index 4190cf7988a6..d58f167335d7 100644 --- a/selfservice/strategy/link/sender.go +++ b/selfservice/strategy/link/sender.go @@ -5,7 +5,6 @@ package link import ( "context" - stderrors "errors" "net/url" "github.com/hashicorp/go-retryablehttp" @@ -52,7 +51,7 @@ type ( } ) -var ErrUnknownAddress = stderrors.New("verification requested for unknown address") +var ErrUnknownAddress = errors.New("verification requested for unknown address") func NewSender(r senderDependencies) *Sender { return &Sender{r: r} diff --git a/selfservice/strategy/link/sender_test.go b/selfservice/strategy/link/sender_test.go index de0dc767d6da..8fd62e038f3e 100644 --- a/selfservice/strategy/link/sender_test.go +++ b/selfservice/strategy/link/sender_test.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "net/http/httptest" + "sync" "testing" "time" @@ -61,9 +62,6 @@ func TestManager(t *testing.T) { require.NoError(t, err) require.Len(t, messages, 2) - require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[0].ID, courier.MessageStatusSent)) - require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[1].ID, courier.MessageStatusSent)) - assert.EqualValues(t, "tracked@ory.sh", messages[0].Recipient) assert.Contains(t, messages[0].Subject, "Recover access to your account") assert.Contains(t, messages[0].Body, urlx.AppendPaths(conf.SelfServiceLinkMethodBaseURL(ctx), recovery.RouteSubmitFlow).String()+"?") @@ -78,6 +76,8 @@ func TestManager(t *testing.T) { }) t.Run("method=SendRecoveryLink via HTTP", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(2) type requestBody struct { Recipient string RecoveryURL string `json:"recovery_url"` @@ -92,6 +92,7 @@ func TestManager(t *testing.T) { var message requestBody require.NoError(t, json.Unmarshal(b, &message)) messages = append(messages, &message) + wg.Done() })) t.Cleanup(srv.Close) requestConfig := fmt.Sprintf(`{"url": "%s"}`, srv.URL) @@ -101,6 +102,12 @@ func TestManager(t *testing.T) { cour, err := reg.Courier(ctx) require.NoError(t, err) + ctx, cancel := context.WithCancel(ctx) + defer t.Cleanup(cancel) + go func() { + require.NoError(t, cour.Work(ctx)) + }() + s, err := reg.RecoveryStrategies(ctx).Strategy("link") require.NoError(t, err) f, err := recovery.NewFlow(conf, time.Hour, "", u, s, flow.TypeBrowser) @@ -109,9 +116,9 @@ func TestManager(t *testing.T) { require.NoError(t, reg.RecoveryFlowPersister().CreateRecoveryFlow(context.Background(), f)) require.NoError(t, reg.LinkSender().SendRecoveryLink(context.Background(), f, "email", "tracked@ory.sh")) - require.ErrorIs(t, reg.LinkSender().SendRecoveryLink(context.Background(), f, "email", "not-tracked@ory.sh"), link.ErrUnknownAddress) + require.EqualError(t, reg.LinkSender().SendRecoveryLink(context.Background(), f, "email", "not-tracked@ory.sh"), link.ErrUnknownAddress.Error()) - require.NoError(t, cour.DispatchQueue(ctx)) + wg.Wait() assert.EqualValues(t, "tracked@ory.sh", messages[0].To) assert.Contains(t, messages[0].Subject, "Recover access to your account") @@ -132,14 +139,11 @@ func TestManager(t *testing.T) { require.NoError(t, reg.VerificationFlowPersister().CreateVerificationFlow(context.Background(), f)) require.NoError(t, reg.LinkSender().SendVerificationLink(context.Background(), f, "email", "tracked@ory.sh")) - require.ErrorIs(t, reg.LinkSender().SendVerificationLink(context.Background(), f, "email", "not-tracked@ory.sh"), link.ErrUnknownAddress) + require.EqualError(t, reg.LinkSender().SendVerificationLink(context.Background(), f, "email", "not-tracked@ory.sh"), link.ErrUnknownAddress.Error()) messages, err := reg.CourierPersister().NextMessages(context.Background(), 12) require.NoError(t, err) require.Len(t, messages, 2) - require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[0].ID, courier.MessageStatusSent)) - require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[1].ID, courier.MessageStatusSent)) - assert.EqualValues(t, "tracked@ory.sh", messages[0].Recipient) assert.Contains(t, messages[0].Subject, "Please verify") assert.Contains(t, messages[0].Body, urlx.AppendPaths(conf.SelfServiceLinkMethodBaseURL(ctx), verification.RouteSubmitFlow).String()+"?") @@ -203,7 +207,7 @@ func TestManager(t *testing.T) { messages, err := reg.CourierPersister().NextMessages(context.Background(), 0) - require.NoError(t, err) + require.ErrorIs(t, err, courier.ErrQueueEmpty) require.Len(t, messages, 0) }) } diff --git a/selfservice/strategy/profile/strategy_test.go b/selfservice/strategy/profile/strategy_test.go index d489e344a6ce..f67407fe799f 100644 --- a/selfservice/strategy/profile/strategy_test.go +++ b/selfservice/strategy/profile/strategy_test.go @@ -17,8 +17,6 @@ import ( "testing" "time" - "github.com/ory/kratos/courier" - "github.com/ory/x/jsonx" kratos "github.com/ory/kratos/internal/httpclient" @@ -533,12 +531,9 @@ func TestStrategyTraits(t *testing.T) { assert.EqualValues(t, flow.StateSuccess, gjson.Get(actual, "state").String(), "%s", actual) assert.Equal(t, newEmail, gjson.Get(actual, "ui.nodes.#(attributes.name==traits.email).attributes.value").Value(), "%s", actual) - ms, err := reg.CourierPersister().NextMessages(ctx, 10) + m, err := reg.CourierPersister().LatestQueuedMessage(context.Background()) require.NoError(t, err) - require.Len(t, ms, 1) - assert.Contains(t, ms[0].Subject, "verify your email address") - - require.NoError(t, reg.CourierPersister().SetMessageStatus(ctx, ms[0].ID, courier.MessageStatusSent)) + assert.Contains(t, m.Subject, "verify your email address") } payload := func(newEmail string) func(v url.Values) { @@ -555,13 +550,13 @@ func TestStrategyTraits(t *testing.T) { }) t.Run("type=spa", func(t *testing.T) { - newEmail := "update-verify-browser-1@mail.com" + newEmail := "update-verify-browser@mail.com" actual := expectSuccess(t, false, true, browserUser1, payload(newEmail)) check(t, actual, newEmail) }) t.Run("type=browser", func(t *testing.T) { - newEmail := "update-verify-browser-2@mail.com" + newEmail := "update-verify-browser@mail.com" actual := expectSuccess(t, false, false, browserUser1, payload(newEmail)) check(t, actual, newEmail) })