From c33769129e113d0e2db1984c522351c48ef41200 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Fri, 25 Nov 2022 14:41:12 +0100 Subject: [PATCH] fix(GODT-2156): Use UUID instead of Integers for internal ID Avoids collision issues if users try to copy messages between different Gluon based servers. --- .../gluon_bench/store_benchmarks/create.go | 2 +- .../gluon_bench/store_benchmarks/utils.go | 2 +- imap/strong_types.go | 27 ++++--- internal/backend/connector_updates.go | 4 +- internal/backend/state_connector_impl.go | 4 +- internal/backend/state_user_interface_impl.go | 4 - internal/backend/user.go | 16 ---- internal/db/ent/message.go | 10 ++- internal/db/ent/message_create.go | 17 ++--- internal/db/ent/message_delete.go | 2 +- internal/db/ent/message_query.go | 2 +- internal/db/ent/message_update.go | 4 +- internal/db/ent/messageflag.go | 6 +- internal/db/ent/messageflag_create.go | 2 +- internal/db/ent/messageflag_update.go | 8 +- internal/db/ent/migrate/schema.go | 6 +- internal/db/ent/schema/message.go | 2 +- internal/db/ent/uid.go | 6 +- internal/db/ent/uid_create.go | 2 +- internal/db/ent/uid_update.go | 8 +- internal/db/mailbox.go | 14 +++- internal/db/message.go | 8 +- internal/state/actions.go | 2 +- internal/state/snapshot_messages_test.go | 74 +++++++++++-------- internal/state/user_interface.go | 2 - store/store_test.go | 32 +++++--- tests/append_test.go | 8 +- tests/counts_test.go | 2 +- tests/fetch_test.go | 3 +- tests/helper_test.go | 46 ++++++++++++ tests/recovery_mailbox_test.go | 21 ++++-- tests/store_cleanup_test.go | 2 +- 32 files changed, 205 insertions(+), 143 deletions(-) diff --git a/benchmarks/gluon_bench/store_benchmarks/create.go b/benchmarks/gluon_bench/store_benchmarks/create.go index 67e61b5d..7d5de047 100644 --- a/benchmarks/gluon_bench/store_benchmarks/create.go +++ b/benchmarks/gluon_bench/store_benchmarks/create.go @@ -31,7 +31,7 @@ func (*Create) Run(ctx context.Context, st store.Store) (*reporter.BenchmarkRun, for i := uint(0); i < *flags.StoreItemCount; i++ { dc.Start() - err := s.Set(imap.InternalMessageID(uint64(i)), data) + err := s.Set(imap.NewInternalMessageID(), data) dc.Stop() if err != nil { diff --git a/benchmarks/gluon_bench/store_benchmarks/utils.go b/benchmarks/gluon_bench/store_benchmarks/utils.go index 5f303541..581aad79 100644 --- a/benchmarks/gluon_bench/store_benchmarks/utils.go +++ b/benchmarks/gluon_bench/store_benchmarks/utils.go @@ -17,7 +17,7 @@ func CreateRandomState(st store.Store, count uint) ([]imap.InternalMessageID, er data := make([]byte, *flags.StoreItemSize) for i := uint(0); i < count; i++ { - uuid := imap.InternalMessageID(uint64(i)) + uuid := imap.NewInternalMessageID() if err := st.Set(uuid, data); err != nil { return nil, err diff --git a/imap/strong_types.go b/imap/strong_types.go index c05362fb..d8f40899 100644 --- a/imap/strong_types.go +++ b/imap/strong_types.go @@ -1,11 +1,11 @@ package imap import ( - "encoding/binary" "fmt" "strconv" "github.com/ProtonMail/gluon/internal/utils" + "github.com/google/uuid" ) type MailboxID string @@ -20,7 +20,9 @@ func (m MessageID) ShortID() string { return utils.ShortID(string(m)) } -type InternalMessageID uint64 +type InternalMessageID struct { + uuid.UUID +} type InternalMailboxID uint64 @@ -29,7 +31,7 @@ func (i InternalMailboxID) ShortID() string { } func (i InternalMessageID) ShortID() string { - return strconv.FormatUint(uint64(i), 10) + return utils.ShortID(i.String()) } func (i InternalMailboxID) String() string { @@ -37,23 +39,20 @@ func (i InternalMailboxID) String() string { } func (i InternalMessageID) String() string { - return strconv.FormatUint(uint64(i), 10) + return i.UUID.String() +} + +func NewInternalMessageID() InternalMessageID { + return InternalMessageID{UUID: uuid.New()} } func InternalMessageIDFromString(id string) (InternalMessageID, error) { - num, err := strconv.ParseUint(id, 10, 64) + num, err := uuid.Parse(id) if err != nil { - return 0, fmt.Errorf("invalid message id:%w", err) + return InternalMessageID{}, fmt.Errorf("invalid message id:%w", err) } - return InternalMessageID(num), nil -} - -func (i InternalMessageID) ToBytes() []byte { - bytes := make([]byte, 8) - binary.LittleEndian.PutUint64(bytes, uint64(i)) - - return bytes + return InternalMessageID{UUID: num}, nil } type UID uint32 diff --git a/internal/backend/connector_updates.go b/internal/backend/connector_updates.go index b9c00708..9e301885 100644 --- a/internal/backend/connector_updates.go +++ b/internal/backend/connector_updates.go @@ -210,7 +210,7 @@ func (user *user) applyMessagesCreated(ctx context.Context, update *imap.Message continue } - internalID = user.nextMessageID() + internalID = imap.NewInternalMessageID() literal, err := rfc822.SetHeaderValue(message.Literal, ids.InternalIDKey, internalID.String()) if err != nil { @@ -597,7 +597,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU } // create new entry { - newInternalID := user.nextMessageID() + newInternalID := imap.NewInternalMessageID() literal, err := rfc822.SetHeaderValue(update.Literal, ids.InternalIDKey, newInternalID.String()) if err != nil { return fmt.Errorf("failed to set internal ID: %w", err) diff --git a/internal/backend/state_connector_impl.go b/internal/backend/state_connector_impl.go index 2044fcb2..c49879dd 100644 --- a/internal/backend/state_connector_impl.go +++ b/internal/backend/state_connector_impl.go @@ -69,10 +69,10 @@ func (sc *stateConnectorImpl) CreateMessage( msg, newLiteral, err := sc.connector.CreateMessage(ctx, mboxID, literal, flags, date) if err != nil { - return 0, imap.Message{}, nil, err + return imap.InternalMessageID{}, imap.Message{}, nil, err } - return sc.user.nextMessageID(), msg, newLiteral, nil + return imap.NewInternalMessageID(), msg, newLiteral, nil } func (sc *stateConnectorImpl) AddMessagesToMailbox( diff --git a/internal/backend/state_user_interface_impl.go b/internal/backend/state_user_interface_impl.go index 79032610..0710483b 100644 --- a/internal/backend/state_user_interface_impl.go +++ b/internal/backend/state_user_interface_impl.go @@ -93,7 +93,3 @@ func (s *StateUserInterfaceImpl) GetRecoveryMailboxID() ids.MailboxIDPair { RemoteID: ids.GluonInternalRecoveryMailboxRemoteID, } } - -func (s *StateUserInterfaceImpl) NextRecoveryMessageID() imap.InternalMessageID { - return s.u.nextMessageID() -} diff --git a/internal/backend/user.go b/internal/backend/user.go index 68869589..f8e5c0d3 100644 --- a/internal/backend/user.go +++ b/internal/backend/user.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "sync" - "sync/atomic" "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" @@ -37,7 +36,6 @@ type user struct { updateWG sync.WaitGroup updateQuitCh chan struct{} - messageIDCounter uint64 globalUIDValidity imap.UID recoveryMailboxID imap.InternalMailboxID @@ -68,15 +66,6 @@ func newUser( return db.GetOrCreateMailbox(ctx, tx, mbox, delimiter, conn.GetUIDValidity()) }) - - if err != nil { - return nil, err - } - - // Get the last message ID from the database so we can resume our counter properly. - highestMessageID, err := db.ReadResult(ctx, database, func(ctx context.Context, client *ent.Client) (imap.InternalMessageID, error) { - return db.GetHighestMessageID(ctx, client) - }) if err != nil { return nil, err } @@ -93,7 +82,6 @@ func newUser( states: make(map[state.StateID]*state.State), updateQuitCh: make(chan struct{}), - messageIDCounter: uint64(highestMessageID), globalUIDValidity: conn.GetUIDValidity(), recoveryMailboxID: recoveryMBox.ID, @@ -311,10 +299,6 @@ func (user *user) closeStates() { } } -func (user *user) nextMessageID() imap.InternalMessageID { - return imap.InternalMessageID(atomic.AddUint64(&user.messageIDCounter, 1)) -} - func (user *user) cleanupStaleStoreData(ctx context.Context) error { storeIds, err := user.store.List() if err != nil { diff --git a/internal/db/ent/message.go b/internal/db/ent/message.go index bd4a10a7..3c8ac699 100644 --- a/internal/db/ent/message.go +++ b/internal/db/ent/message.go @@ -70,9 +70,11 @@ func (*Message) scanValues(columns []string) ([]interface{}, error) { values := make([]interface{}, len(columns)) for i := range columns { switch columns[i] { + case message.FieldID: + values[i] = new(imap.InternalMessageID) case message.FieldDeleted: values[i] = new(sql.NullBool) - case message.FieldID, message.FieldSize: + case message.FieldSize: values[i] = new(sql.NullInt64) case message.FieldRemoteID, message.FieldBody, message.FieldBodyStructure, message.FieldEnvelope: values[i] = new(sql.NullString) @@ -94,10 +96,10 @@ func (m *Message) assignValues(columns []string, values []interface{}) error { for i := range columns { switch columns[i] { case message.FieldID: - if value, ok := values[i].(*sql.NullInt64); !ok { + if value, ok := values[i].(*imap.InternalMessageID); !ok { return fmt.Errorf("unexpected type %T for field id", values[i]) - } else if value.Valid { - m.ID = imap.InternalMessageID(value.Int64) + } else if value != nil { + m.ID = *value } case message.FieldRemoteID: if value, ok := values[i].(*sql.NullString); !ok { diff --git a/internal/db/ent/message_create.go b/internal/db/ent/message_create.go index e1346e0c..c9d923a0 100644 --- a/internal/db/ent/message_create.go +++ b/internal/db/ent/message_create.go @@ -231,9 +231,12 @@ func (mc *MessageCreate) sqlSave(ctx context.Context) (*Message, error) { } return nil, err } - if _spec.ID.Value != _node.ID { - id := _spec.ID.Value.(int64) - _node.ID = imap.InternalMessageID(id) + if _spec.ID.Value != nil { + if id, ok := _spec.ID.Value.(*imap.InternalMessageID); ok { + _node.ID = *id + } else if err := _node.ID.Scan(_spec.ID.Value); err != nil { + return nil, err + } } return _node, nil } @@ -244,14 +247,14 @@ func (mc *MessageCreate) createSpec() (*Message, *sqlgraph.CreateSpec) { _spec = &sqlgraph.CreateSpec{ Table: message.Table, ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, } ) if id, ok := mc.mutation.ID(); ok { _node.ID = id - _spec.ID.Value = id + _spec.ID.Value = &id } if value, ok := mc.mutation.RemoteID(); ok { _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ @@ -391,10 +394,6 @@ func (mcb *MessageCreateBulk) Save(ctx context.Context) ([]*Message, error) { return nil, err } mutation.id = &nodes[i].ID - if specs[i].ID.Value != nil && nodes[i].ID == 0 { - id := specs[i].ID.Value.(int64) - nodes[i].ID = imap.InternalMessageID(id) - } mutation.done = true return nodes[i], nil }) diff --git a/internal/db/ent/message_delete.go b/internal/db/ent/message_delete.go index d6d28f3e..b5094106 100644 --- a/internal/db/ent/message_delete.go +++ b/internal/db/ent/message_delete.go @@ -72,7 +72,7 @@ func (md *MessageDelete) sqlExec(ctx context.Context) (int, error) { Node: &sqlgraph.NodeSpec{ Table: message.Table, ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, diff --git a/internal/db/ent/message_query.go b/internal/db/ent/message_query.go index bcdd39c1..84416328 100644 --- a/internal/db/ent/message_query.go +++ b/internal/db/ent/message_query.go @@ -515,7 +515,7 @@ func (mq *MessageQuery) querySpec() *sqlgraph.QuerySpec { Table: message.Table, Columns: message.Columns, ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, diff --git a/internal/db/ent/message_update.go b/internal/db/ent/message_update.go index 955e0fc5..15564592 100644 --- a/internal/db/ent/message_update.go +++ b/internal/db/ent/message_update.go @@ -239,7 +239,7 @@ func (mu *MessageUpdate) sqlSave(ctx context.Context) (n int, err error) { Table: message.Table, Columns: message.Columns, ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, @@ -661,7 +661,7 @@ func (muo *MessageUpdateOne) sqlSave(ctx context.Context) (_node *Message, err e Table: message.Table, Columns: message.Columns, ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, diff --git a/internal/db/ent/messageflag.go b/internal/db/ent/messageflag.go index 748565a6..f1d82172 100644 --- a/internal/db/ent/messageflag.go +++ b/internal/db/ent/messageflag.go @@ -57,7 +57,7 @@ func (*MessageFlag) scanValues(columns []string) ([]interface{}, error) { case messageflag.FieldValue: values[i] = new(sql.NullString) case messageflag.ForeignKeys[0]: // message_flags - values[i] = new(sql.NullInt64) + values[i] = &sql.NullScanner{S: new(imap.InternalMessageID)} default: return nil, fmt.Errorf("unexpected column %q for type MessageFlag", columns[i]) } @@ -86,11 +86,11 @@ func (mf *MessageFlag) assignValues(columns []string, values []interface{}) erro mf.Value = value.String } case messageflag.ForeignKeys[0]: - if value, ok := values[i].(*sql.NullInt64); !ok { + if value, ok := values[i].(*sql.NullScanner); !ok { return fmt.Errorf("unexpected type %T for field message_flags", values[i]) } else if value.Valid { mf.message_flags = new(imap.InternalMessageID) - *mf.message_flags = imap.InternalMessageID(value.Int64) + *mf.message_flags = *value.S.(*imap.InternalMessageID) } } } diff --git a/internal/db/ent/messageflag_create.go b/internal/db/ent/messageflag_create.go index d2120aae..7b7bec11 100644 --- a/internal/db/ent/messageflag_create.go +++ b/internal/db/ent/messageflag_create.go @@ -169,7 +169,7 @@ func (mfc *MessageFlagCreate) createSpec() (*MessageFlag, *sqlgraph.CreateSpec) Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, diff --git a/internal/db/ent/messageflag_update.go b/internal/db/ent/messageflag_update.go index 1ef91a89..923a9270 100644 --- a/internal/db/ent/messageflag_update.go +++ b/internal/db/ent/messageflag_update.go @@ -153,7 +153,7 @@ func (mfu *MessageFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, @@ -169,7 +169,7 @@ func (mfu *MessageFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, @@ -352,7 +352,7 @@ func (mfuo *MessageFlagUpdateOne) sqlSave(ctx context.Context) (_node *MessageFl Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, @@ -368,7 +368,7 @@ func (mfuo *MessageFlagUpdateOne) sqlSave(ctx context.Context) (_node *MessageFl Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, diff --git a/internal/db/ent/migrate/schema.go b/internal/db/ent/migrate/schema.go index 6b434dfe..545188bb 100644 --- a/internal/db/ent/migrate/schema.go +++ b/internal/db/ent/migrate/schema.go @@ -102,7 +102,7 @@ var ( } // MessagesColumns holds the columns for the "messages" table. MessagesColumns = []*schema.Column{ - {Name: "id", Type: field.TypeUint64, Increment: true}, + {Name: "id", Type: field.TypeUUID, Unique: true}, {Name: "remote_id", Type: field.TypeString, Unique: true, Nullable: true}, {Name: "date", Type: field.TypeTime}, {Name: "size", Type: field.TypeInt}, @@ -133,7 +133,7 @@ var ( MessageFlagsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "value", Type: field.TypeString}, - {Name: "message_flags", Type: field.TypeUint64, Nullable: true}, + {Name: "message_flags", Type: field.TypeUUID, Nullable: true}, } // MessageFlagsTable holds the schema information for the "message_flags" table. MessageFlagsTable = &schema.Table{ @@ -156,7 +156,7 @@ var ( {Name: "deleted", Type: field.TypeBool, Default: false}, {Name: "recent", Type: field.TypeBool, Default: true}, {Name: "mailbox_ui_ds", Type: field.TypeUint64, Nullable: true}, - {Name: "uid_message", Type: field.TypeUint64, Nullable: true}, + {Name: "uid_message", Type: field.TypeUUID, Nullable: true}, } // UIDsTable holds the schema information for the "ui_ds" table. UIDsTable = &schema.Table{ diff --git a/internal/db/ent/schema/message.go b/internal/db/ent/schema/message.go index 28a61f81..7a597cd1 100644 --- a/internal/db/ent/schema/message.go +++ b/internal/db/ent/schema/message.go @@ -17,7 +17,7 @@ type Message struct { // Fields of the Message. func (Message) Fields() []ent.Field { return []ent.Field{ - field.Uint64("id").GoType(imap.InternalMessageID(0)).Unique().Immutable(), + field.UUID("id", imap.NewInternalMessageID()).Unique().Immutable(), field.String("RemoteID").Optional().Unique().GoType(imap.MessageID("")), field.Time("Date"), field.Int("Size"), diff --git a/internal/db/ent/uid.go b/internal/db/ent/uid.go index aa6ffed6..5eb4b9f9 100644 --- a/internal/db/ent/uid.go +++ b/internal/db/ent/uid.go @@ -80,7 +80,7 @@ func (*UID) scanValues(columns []string) ([]interface{}, error) { case uid.ForeignKeys[0]: // mailbox_ui_ds values[i] = new(sql.NullInt64) case uid.ForeignKeys[1]: // uid_message - values[i] = new(sql.NullInt64) + values[i] = &sql.NullScanner{S: new(imap.InternalMessageID)} default: return nil, fmt.Errorf("unexpected column %q for type UID", columns[i]) } @@ -128,11 +128,11 @@ func (u *UID) assignValues(columns []string, values []interface{}) error { *u.mailbox_ui_ds = imap.InternalMailboxID(value.Int64) } case uid.ForeignKeys[1]: - if value, ok := values[i].(*sql.NullInt64); !ok { + if value, ok := values[i].(*sql.NullScanner); !ok { return fmt.Errorf("unexpected type %T for field uid_message", values[i]) } else if value.Valid { u.uid_message = new(imap.InternalMessageID) - *u.uid_message = imap.InternalMessageID(value.Int64) + *u.uid_message = *value.S.(*imap.InternalMessageID) } } } diff --git a/internal/db/ent/uid_create.go b/internal/db/ent/uid_create.go index 72fe2492..59d5565a 100644 --- a/internal/db/ent/uid_create.go +++ b/internal/db/ent/uid_create.go @@ -252,7 +252,7 @@ func (uc *UIDCreate) createSpec() (*UID, *sqlgraph.CreateSpec) { Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, diff --git a/internal/db/ent/uid_update.go b/internal/db/ent/uid_update.go index 07850845..8eb1a062 100644 --- a/internal/db/ent/uid_update.go +++ b/internal/db/ent/uid_update.go @@ -235,7 +235,7 @@ func (uu *UIDUpdate) sqlSave(ctx context.Context) (n int, err error) { Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, @@ -251,7 +251,7 @@ func (uu *UIDUpdate) sqlSave(ctx context.Context) (n int, err error) { Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, @@ -550,7 +550,7 @@ func (uuo *UIDUpdateOne) sqlSave(ctx context.Context) (_node *UID, err error) { Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, @@ -566,7 +566,7 @@ func (uuo *UIDUpdateOne) sqlSave(ctx context.Context) (_node *UID, err error) { Bidi: false, Target: &sqlgraph.EdgeTarget{ IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, + Type: field.TypeUUID, Column: message.FieldID, }, }, diff --git a/internal/db/mailbox.go b/internal/db/mailbox.go index 0e16dcc1..19e03975 100644 --- a/internal/db/mailbox.go +++ b/internal/db/mailbox.go @@ -347,16 +347,22 @@ func GetOrCreateMailbox(ctx context.Context, tx *ent.Tx, mbox imap.Mailbox, deli } func FilterMailboxContains(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID, messageIDs []ids.MessageIDPair) ([]imap.InternalMessageID, error) { - var result []imap.InternalMessageID + type result struct { + InternalID imap.InternalMessageID `json:"uid_message"` + } + + var r []result if err := client.UID.Query().Where(func(s *sql.Selector) { s.Where(sql.And(sql.EQ(uid.MailboxColumn, mboxID), sql.In(uid.MessageColumn, xslices.Map(messageIDs, func(id ids.MessageIDPair) interface{} { - return uint64(id.InternalID) + return id.InternalID })...))) s.Select(uid.MessageColumn) - }).Select().Scan(ctx, &result); err != nil { + }).Select().Scan(ctx, &r); err != nil { return nil, err } - return result, nil + return xslices.Map(r, func(r result) imap.InternalMessageID { + return r.InternalID + }), nil } diff --git a/internal/db/message.go b/internal/db/message.go index a1f15f7d..4d62f089 100644 --- a/internal/db/message.go +++ b/internal/db/message.go @@ -305,7 +305,7 @@ func GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, client *ent.C s.Join(msgTable).On(s.C(uid.MessageColumn), msgTable.C(message.FieldID)) s.LeftJoin(flagTable).On(s.C(uid.MessageColumn), flagTable.C(messageflag.MessagesColumn)) s.Where(sql.And(sql.EQ(uid.MailboxColumn, mboxID), sql.In(s.C(uid.MessageColumn), xslices.Map(chunk, func(id imap.InternalMessageID) interface{} { - return uint64(id) + return id })...))) s.Select(msgTable.C(message.FieldRemoteID), sql.As(fmt.Sprintf("GROUP_CONCAT(%v)", flagTable.C(messageflag.FieldValue)), "flags"), s.C(uid.FieldRecent), s.C(uid.FieldDeleted), s.C(uid.FieldUID), s.C(uid.MessageColumn)) s.GroupBy(s.C(uid.MessageColumn)) @@ -590,7 +590,7 @@ func HasMessageWithRemoteID(ctx context.Context, client *ent.Client, id imap.Mes func GetMessageIDFromRemoteID(ctx context.Context, client *ent.Client, id imap.MessageID) (imap.InternalMessageID, error) { message, err := client.Message.Query().Where(message.RemoteID(id)).Select(message.FieldID).Only(ctx) if err != nil { - return 0, err + return imap.InternalMessageID{}, err } return message.ID, nil @@ -642,11 +642,11 @@ func NewFlagSet(msgUID *ent.UID, flags []*ent.MessageFlag) imap.FlagSet { func GetHighestMessageID(ctx context.Context, client *ent.Client) (imap.InternalMessageID, error) { message, err := client.Message.Query().Select(message.FieldID).Order(ent.Desc(message.FieldID)).Limit(1).All(ctx) if err != nil { - return 0, err + return imap.InternalMessageID{}, err } if len(message) == 0 { - return 0, nil + return imap.InternalMessageID{}, nil } return message[0].ID, nil diff --git a/internal/state/actions.go b/internal/state/actions.go index 5ff2f574..145ab168 100644 --- a/internal/state/actions.go +++ b/internal/state/actions.go @@ -178,7 +178,7 @@ func (state *State) actionCreateRecoveredMessage( flags imap.FlagSet, date time.Time, ) error { - internalID := state.user.NextRecoveryMessageID() + internalID := imap.NewInternalMessageID() remoteID := imap.MessageID(fmt.Sprintf("GLUON-RECOVERED-MESSAGE-%v", internalID)) parsedMessage, err := imap.NewParsedMessage(literal) diff --git a/internal/state/snapshot_messages_test.go b/internal/state/snapshot_messages_test.go index c61e6179..6be9134a 100644 --- a/internal/state/snapshot_messages_test.go +++ b/internal/state/snapshot_messages_test.go @@ -12,58 +12,65 @@ import ( func TestMessages(t *testing.T) { msg := newMsgList(8) - msg.insert(messageIDPair(1, "1"), 10, imap.NewFlagSet(imap.FlagSeen)) - msg.insert(messageIDPair(2, "2"), 20, imap.NewFlagSet(imap.FlagSeen)) - msg.insert(messageIDPair(3, "3"), 30, imap.NewFlagSet(imap.FlagSeen)) - msg.insert(messageIDPair(4, "4"), 40, imap.NewFlagSet(imap.FlagSeen)) - msg.insert(messageIDPair(5, "5"), 50, imap.NewFlagSet(imap.FlagSeen)) - msg.insert(messageIDPair(6, "6"), 60, imap.NewFlagSet(imap.FlagSeen)) - - msg.remove(2) - msg.remove(4) - msg.remove(6) + id1 := imap.NewInternalMessageID() + id2 := imap.NewInternalMessageID() + id3 := imap.NewInternalMessageID() + id4 := imap.NewInternalMessageID() + id5 := imap.NewInternalMessageID() + id6 := imap.NewInternalMessageID() + + msg.insert(messageIDPair(id1, "1"), 10, imap.NewFlagSet(imap.FlagSeen)) + msg.insert(messageIDPair(id2, "2"), 20, imap.NewFlagSet(imap.FlagSeen)) + msg.insert(messageIDPair(id3, "3"), 30, imap.NewFlagSet(imap.FlagSeen)) + msg.insert(messageIDPair(id4, "4"), 40, imap.NewFlagSet(imap.FlagSeen)) + msg.insert(messageIDPair(id5, "5"), 50, imap.NewFlagSet(imap.FlagSeen)) + msg.insert(messageIDPair(id6, "6"), 60, imap.NewFlagSet(imap.FlagSeen)) + + msg.remove(id2) + msg.remove(id4) + msg.remove(id6) { require.Equal(t, 3, msg.len()) } { - require.True(t, msg.has(1)) - require.True(t, msg.has(3)) - require.True(t, msg.has(5)) + require.True(t, msg.has(id1)) + require.True(t, msg.has(id3)) + require.True(t, msg.has(id5)) - require.False(t, msg.has(2)) - require.False(t, msg.has(4)) - require.False(t, msg.has(6)) + require.False(t, msg.has(id2)) + require.False(t, msg.has(id4)) + require.False(t, msg.has(id6)) } { - msg1, ok := msg.get(1) + msg1, ok := msg.get(id1) require.True(t, ok) require.Equal(t, imap.SeqID(1), msg1.Seq) require.Equal(t, imap.UID(10), msg1.UID) - _, ok = msg.get(2) + _, ok = msg.get(id2) require.False(t, ok) - msg3, ok := msg.get(3) + msg3, ok := msg.get(id3) require.True(t, ok) require.Equal(t, imap.SeqID(2), msg3.Seq) require.Equal(t, imap.UID(30), msg3.UID) - _, ok = msg.get(4) + _, ok = msg.get(id4) require.False(t, ok) - msg5, ok := msg.get(5) + msg5, ok := msg.get(id5) require.True(t, ok) require.Equal(t, imap.SeqID(3), msg5.Seq) require.Equal(t, imap.UID(50), msg5.UID) } { - require.Equal(t, must(msg.get(1)), must(msg.seq(1))) - require.Equal(t, must(msg.get(3)), must(msg.seq(2))) - require.Equal(t, must(msg.get(5)), must(msg.seq(3))) + require.Equal(t, must(msg.get(id1)), must(msg.seq(1))) + require.Equal(t, must(msg.get(id3)), must(msg.seq(2))) + require.Equal(t, must(msg.get(id5)), must(msg.seq(3))) } } @@ -71,12 +78,19 @@ func TestMessages(t *testing.T) { func TestMessageUIDRange(t *testing.T) { msg := newMsgList(8) - msg.insert(messageIDPair(1, "1"), 10, imap.NewFlagSet(imap.FlagSeen)) - msg.insert(messageIDPair(2, "2"), 20, imap.NewFlagSet(imap.FlagSeen)) - msg.insert(messageIDPair(3, "3"), 30, imap.NewFlagSet(imap.FlagSeen)) - msg.insert(messageIDPair(4, "4"), 40, imap.NewFlagSet(imap.FlagSeen)) - msg.insert(messageIDPair(5, "5"), 50, imap.NewFlagSet(imap.FlagSeen)) - msg.insert(messageIDPair(6, "6"), 60, imap.NewFlagSet(imap.FlagSeen)) + id1 := imap.NewInternalMessageID() + id2 := imap.NewInternalMessageID() + id3 := imap.NewInternalMessageID() + id4 := imap.NewInternalMessageID() + id5 := imap.NewInternalMessageID() + id6 := imap.NewInternalMessageID() + + msg.insert(messageIDPair(id1, "1"), 10, imap.NewFlagSet(imap.FlagSeen)) + msg.insert(messageIDPair(id2, "2"), 20, imap.NewFlagSet(imap.FlagSeen)) + msg.insert(messageIDPair(id3, "3"), 30, imap.NewFlagSet(imap.FlagSeen)) + msg.insert(messageIDPair(id4, "4"), 40, imap.NewFlagSet(imap.FlagSeen)) + msg.insert(messageIDPair(id5, "5"), 50, imap.NewFlagSet(imap.FlagSeen)) + msg.insert(messageIDPair(id6, "6"), 60, imap.NewFlagSet(imap.FlagSeen)) // UIDRange Higher than maximum { diff --git a/internal/state/user_interface.go b/internal/state/user_interface.go index 3ccd6d79..58ee1b12 100644 --- a/internal/state/user_interface.go +++ b/internal/state/user_interface.go @@ -33,6 +33,4 @@ type UserInterface interface { SetGlobalUIDValidity(imap.UID) GetRecoveryMailboxID() ids.MailboxIDPair - - NextRecoveryMessageID() imap.InternalMessageID } diff --git a/store/store_test.go b/store/store_test.go index 84bc0ef6..63377a9b 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -31,23 +31,35 @@ func TestOnDiskStore(t *testing.T) { } func testStore(t *testing.T, store store.Store) { - require.NoError(t, store.Set(1, []byte("literal1"))) - require.NoError(t, store.Set(2, []byte("literal2"))) - require.NoError(t, store.Set(3, []byte("literal3"))) + id1 := imap.NewInternalMessageID() + id2 := imap.NewInternalMessageID() + id3 := imap.NewInternalMessageID() - require.Equal(t, []byte("literal1"), must(store.Get(1))) - require.Equal(t, []byte("literal2"), must(store.Get(2))) - require.Equal(t, []byte("literal3"), must(store.Get(3))) + require.NoError(t, store.Set(id1, []byte("literal1"))) + require.NoError(t, store.Set(id2, []byte("literal2"))) + require.NoError(t, store.Set(id3, []byte("literal3"))) + + require.Equal(t, []byte("literal1"), must(store.Get(id1))) + require.Equal(t, []byte("literal2"), must(store.Get(id2))) + require.Equal(t, []byte("literal3"), must(store.Get(id3))) + + require.NoError(t, store.Delete(id1, id2, id3)) } func testStoreList(t *testing.T, store store.Store) { - require.NoError(t, store.Set(1, []byte("literal1"))) - require.NoError(t, store.Set(2, []byte("literal2"))) - require.NoError(t, store.Set(3, []byte("literal3"))) + id1 := imap.NewInternalMessageID() + id2 := imap.NewInternalMessageID() + id3 := imap.NewInternalMessageID() + + require.NoError(t, store.Set(id1, []byte("literal1"))) + require.NoError(t, store.Set(id2, []byte("literal2"))) + require.NoError(t, store.Set(id3, []byte("literal3"))) list, err := store.List() require.NoError(t, err) - require.ElementsMatch(t, list, []imap.InternalMessageID{1, 2, 3}) + require.ElementsMatch(t, list, []imap.InternalMessageID{id1, id2, id3}) + + require.NoError(t, store.Delete(id1, id2, id3)) } func must[T any](val T, err error) T { diff --git a/tests/append_test.go b/tests/append_test.go index 5dc8cfa7..21445330 100644 --- a/tests/append_test.go +++ b/tests/append_test.go @@ -11,7 +11,6 @@ import ( "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/ids" goimap "github.com/emersion/go-imap" uidplus "github.com/emersion/go-imap-uidplus" "github.com/emersion/go-imap/client" @@ -337,24 +336,25 @@ func TestGODT2007AppendInternalIDPresentOnDeletedMessage(t *testing.T) { // Check if the header is correctly set. result := newFetchCommand(t, client).withItems("UID", "BODY[HEADER]").fetch("1") result.forSeqNum(1, func(builder *validatorBuilder) { - builder.ignoreFlags().wantSection("BODY[HEADER]", fmt.Sprintf("%v: 1", ids.InternalIDKey), "To: foo@bar.com\r\n") + builder.ignoreFlags().wantSectionAndSkipGLUONHeaderOrPanic("BODY[HEADER]", "To: foo@bar.com\r\n") builder.wantUID(1) }) result.checkAndRequireMessageCount(1) } + appendedMessage := fetchMessageBody(t, client, 1) s.messageDeleted("user", messageID) s.flush("user") // Add the same message back with the same id - require.NoError(t, doAppendWithClient(client, mailboxName, fmt.Sprintf("%v: 1\r\nTo: foo@bar.com\r\n", ids.InternalIDKey), time.Now())) + require.NoError(t, doAppendWithClient(client, mailboxName, appendedMessage, time.Now())) { // The message should have been created with a new internal id result := newFetchCommand(t, client).withItems("UID", "BODY[HEADER]").fetch("1") result.forSeqNum(1, func(builder *validatorBuilder) { // The header value appears twice because we don't delete existing headers, we only add new ones. - builder.ignoreFlags().wantSection("BODY[HEADER]", fmt.Sprintf("%v: 2", ids.InternalIDKey), fmt.Sprintf("%v: 1", ids.InternalIDKey), "To: foo@bar.com\r\n") + builder.ignoreFlags().wantSectionAndSkipGLUONHeaderOrPanic("BODY[HEADER]", appendedMessage) builder.wantUID(2) }) result.checkAndRequireMessageCount(1) diff --git a/tests/counts_test.go b/tests/counts_test.go index ff59e8da..6badbb0e 100644 --- a/tests/counts_test.go +++ b/tests/counts_test.go @@ -1,11 +1,11 @@ package tests import ( - "github.com/ProtonMail/gluon/internal/ids" "testing" "time" "github.com/ProtonMail/gluon/events" + "github.com/ProtonMail/gluon/internal/ids" goimap "github.com/emersion/go-imap" "github.com/emersion/go-imap/client" "github.com/stretchr/testify/require" diff --git a/tests/fetch_test.go b/tests/fetch_test.go index c985b9d1..94c66307 100644 --- a/tests/fetch_test.go +++ b/tests/fetch_test.go @@ -10,6 +10,7 @@ import ( "github.com/ProtonMail/gluon/internal/ids" goimap "github.com/emersion/go-imap" "github.com/emersion/go-imap/client" + "github.com/google/uuid" "github.com/stretchr/testify/require" ) @@ -553,7 +554,7 @@ func afternoonMeetingMessageDataSize() uint32 { } func afternoonMeetingMessageDataSizeWithExtraHeader() uint32 { - return afternoonMeetingMessageDataSize() + uint32(len(ids.InternalIDKey)) + 5 + return afternoonMeetingMessageDataSize() + uint32(len(ids.InternalIDKey)) + uint32(len(uuid.NewString())+4) } func fillAndSelectMailboxWithMultipleEntries(t *testing.T, client *client.Client) { diff --git a/tests/helper_test.go b/tests/helper_test.go index e84ab6f3..722ef865 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "os" + "strconv" "strings" "testing" "time" @@ -193,6 +194,19 @@ func createSeqSet(sequence string) *goimap.SeqSet { return sequenceSet } +func fetchMessageBody(tb testing.TB, cl *client.Client, seq uint32) string { + result := newFetchCommand(tb, cl).withItems("BODY[]").fetch(strconv.FormatUint(uint64(seq), 10)) + section, err := goimap.ParseBodySectionName("BODY[]") + require.NoError(tb, err) + + literal := getBodySection(result.messages[0], section) + + bytes, err := io.ReadAll(literal) + require.NoError(tb, err) + + return string(bytes) +} + // Helper to validate go-imap-client's message Envelope. type envelopeValidator struct { validateDateTime func(testing.TB, time.Time) @@ -507,6 +521,21 @@ func skipGLUONHeader(message string) string { return message } +func skipGLUONHeaderOrPanic(message string) string { + if keyIndex := strings.Index(message, ids.InternalIDKey); keyIndex != -1 { + newLineIndex := strings.Index(message[keyIndex:], "\n") + if newLineIndex < 0 { + panic("Could not find terminating new line") + } + + message = message[0:keyIndex] + message[keyIndex+newLineIndex+1:] + } else { + panic("Could not find Gluon header") + } + + return message +} + func (vb *validatorBuilder) wantSection(sectionStr goimap.FetchItem, lines ...string) *validatorBuilder { section, err := goimap.ParseBodySectionName(sectionStr) if err != nil { @@ -569,6 +598,23 @@ func (vb *validatorBuilder) wantSectionAndSkipGLUONHeader(sectionStr goimap.Fetc return vb } +func (vb *validatorBuilder) wantSectionAndSkipGLUONHeaderOrPanic(sectionStr goimap.FetchItem, expected ...string) *validatorBuilder { + section, err := goimap.ParseBodySectionName(sectionStr) + if err != nil { + panic("Failed to parse section string") + } + + vb.validateBody = append(vb.validateBody, func(t testing.TB, message *goimap.Message) { + literal := getBodySection(message, section) + require.NotNil(t, literal) + bytes, err := io.ReadAll(literal) + require.NoError(t, err) + require.Equal(t, skipGLUONHeaderOrPanic(string(bytes)), strings.Join(expected, "\r\n")) + }) + + return vb +} + func (vb *validatorBuilder) wantSectionBytes(sectionStr goimap.FetchItem, fn func(testing.TB, []byte)) *validatorBuilder { section, err := goimap.ParseBodySectionName(sectionStr) if err != nil { diff --git a/tests/recovery_mailbox_test.go b/tests/recovery_mailbox_test.go index 6fc88b80..651367cd 100644 --- a/tests/recovery_mailbox_test.go +++ b/tests/recovery_mailbox_test.go @@ -4,14 +4,15 @@ import ( "bytes" "context" "fmt" + "testing" + "time" + "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/ids" goimap "github.com/emersion/go-imap" "github.com/emersion/go-imap/client" "github.com/stretchr/testify/require" - "testing" - "time" ) func TestRecoveryMBoxNotVisibleWhenEmpty(t *testing.T) { @@ -102,7 +103,7 @@ func TestRecoveryMBoxCanBeMovedOutOf(t *testing.T) { // Check that message has the new internal ID header. newFetchCommand(t, client).withItems("BODY[]").fetch("1").forSeqNum(1, func(builder *validatorBuilder) { builder.ignoreFlags() - builder.wantSection("BODY[]", fmt.Sprintf("%v: 2", ids.InternalIDKey), "To: Test@test.com") + builder.wantSectionAndSkipGLUONHeaderOrPanic("BODY[]", "To: Test@test.com") }).checkAndRequireMessageCount(1) } }) @@ -141,7 +142,7 @@ func TestRecoveryMBoxCanBeCopiedOutOf(t *testing.T) { // Check that message has the new internal ID header. newFetchCommand(t, client).withItems("BODY[]").fetch("1").forSeqNum(1, func(builder *validatorBuilder) { builder.ignoreFlags() - builder.wantSection("BODY[]", fmt.Sprintf("%v: 2", ids.InternalIDKey), "To: Test@test.com") + builder.wantSectionAndSkipGLUONHeaderOrPanic("BODY[]", "To: Test@test.com") }).checkAndRequireMessageCount(1) } }) @@ -220,10 +221,12 @@ func TestRecoveryMBoxCanBeCopiedOutOfDedup(t *testing.T) { require.NoError(t, err) newFetchCommand(t, client).withItems("BODY[]").fetch("1").forSeqNum(1, func(builder *validatorBuilder) { builder.ignoreFlags() - builder.wantSection("BODY[]", fmt.Sprintf("%v: 2", ids.InternalIDKey), "To: Test@test.com") + builder.wantSectionAndSkipGLUONHeaderOrPanic("BODY[]", "To: Test@test.com") }).checkAndRequireMessageCount(1) } + msgInInbox := fetchMessageBody(t, client, 1) + // Copy message out of recovery, triggers insert will return the same ID. status, err := client.Select(ids.GluonRecoveryMailboxName, false) require.NoError(t, err) @@ -243,7 +246,7 @@ func TestRecoveryMBoxCanBeCopiedOutOfDedup(t *testing.T) { // Check that message has the new internal ID header. newFetchCommand(t, client).withItems("BODY[]").fetch("1").forSeqNum(1, func(builder *validatorBuilder) { builder.ignoreFlags() - builder.wantSection("BODY[]", fmt.Sprintf("%v: 2", ids.InternalIDKey), "To: Test@test.com") + builder.wantSection("BODY[]", msgInInbox) }).checkAndRequireMessageCount(1) } }) @@ -270,10 +273,12 @@ func TestRecoveryMBoxCanBeMovedOutOfDedup(t *testing.T) { require.NoError(t, err) newFetchCommand(t, client).withItems("BODY[]").fetch("1").forSeqNum(1, func(builder *validatorBuilder) { builder.ignoreFlags() - builder.wantSection("BODY[]", fmt.Sprintf("%v: 2", ids.InternalIDKey), "To: Test@test.com") + builder.wantSectionAndSkipGLUONHeaderOrPanic("BODY[]", "To: Test@test.com") }).checkAndRequireMessageCount(1) } + msgInInbox := fetchMessageBody(t, client, 1) + // Copy message out of recovery, triggers insert will return the same ID. status, err := client.Select(ids.GluonRecoveryMailboxName, false) require.NoError(t, err) @@ -293,7 +298,7 @@ func TestRecoveryMBoxCanBeMovedOutOfDedup(t *testing.T) { // Check that message has the new internal ID header. newFetchCommand(t, client).withItems("BODY[]").fetch("1").forSeqNum(1, func(builder *validatorBuilder) { builder.ignoreFlags() - builder.wantSection("BODY[]", fmt.Sprintf("%v: 2", ids.InternalIDKey), "To: Test@test.com") + builder.wantSection("BODY[]", msgInInbox) }).checkAndRequireMessageCount(1) } }) diff --git a/tests/store_cleanup_test.go b/tests/store_cleanup_test.go index 3aa15f26..802b0e2c 100644 --- a/tests/store_cleanup_test.go +++ b/tests/store_cleanup_test.go @@ -24,7 +24,7 @@ func (t *TestStoreBuilder) New(dir, userID string, passphrase []byte) (store.Sto } testStoreBuilderTestIDs := []imap.InternalMessageID{ - 20414124, 234534834634, 13244367346734098, 2341234234234, + imap.NewInternalMessageID(), imap.NewInternalMessageID(), imap.NewInternalMessageID(), imap.NewInternalMessageID(), } for _, id := range testStoreBuilderTestIDs {