From 3319760d4ca2b734721ae74b72bfd9b26a2e81d7 Mon Sep 17 00:00:00 2001 From: Gurjot Singh <111540954+gusin13@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:52:26 +0530 Subject: [PATCH] fix: batch processing in BTC Subscription Poller (#31) --- cmd/staking-expiry-checker/main.go | 2 +- config/config-docker.yml | 9 ++- config/config-local.yml | 9 ++- internal/config/db.go | 13 +++- internal/db/dbclient.go | 61 +++++++++++++++- internal/db/delegation.go | 42 +++++++---- internal/db/interface.go | 6 +- internal/db/model/delegation.go | 15 ++++ internal/db/model/pagination.go | 27 +++++++ internal/observability/metrics/metrics.go | 16 +++- internal/services/pollers.go | 89 ++++++++++++++--------- internal/services/service.go | 12 ++- tests/mocks/mock_db_client.go | 23 +++--- 13 files changed, 241 insertions(+), 83 deletions(-) create mode 100644 internal/db/model/pagination.go diff --git a/cmd/staking-expiry-checker/main.go b/cmd/staking-expiry-checker/main.go index 09826cd..be99062 100644 --- a/cmd/staking-expiry-checker/main.go +++ b/cmd/staking-expiry-checker/main.go @@ -47,7 +47,7 @@ func main() { defer cancel() // Create DB client - dbClient, err := db.New(ctx, cfg.Db) + dbClient, err := db.New(ctx, &cfg.Db) if err != nil { log.Fatal().Err(err).Msg("error while creating db client") } diff --git a/config/config-docker.yml b/config/config-docker.yml index 02c79b0..22a8a66 100644 --- a/config/config-docker.yml +++ b/config/config-docker.yml @@ -1,16 +1,17 @@ pollers: log-level: debug expiry-checker: - interval: 5s - timeout: 10s + interval: 10s + timeout: 100s btc-subscriber: - interval: 5s - timeout: 10s + interval: 10s + timeout: 100s db: username: root password: example address: "mongodb://localhost:27017" db-name: staking-api-service + max-pagination-limit: 1000 btc: rpchost: 127.0.0.1:38332 rpcuser: rpcuser diff --git a/config/config-local.yml b/config/config-local.yml index 713f6a7..7626445 100644 --- a/config/config-local.yml +++ b/config/config-local.yml @@ -1,16 +1,17 @@ pollers: log-level: debug expiry-checker: - interval: 5s - timeout: 1000s + interval: 10s + timeout: 100s btc-subscriber: - interval: 5s - timeout: 1000s + interval: 10s + timeout: 100s db: username: root password: example address: "mongodb://localhost:27017" db-name: staking-api-service + max-pagination-limit: 1000 btc: rpchost: 127.0.0.1:38332 rpcuser: rpcuser diff --git a/internal/config/db.go b/internal/config/db.go index 03fbc94..c39b1e5 100644 --- a/internal/config/db.go +++ b/internal/config/db.go @@ -7,10 +7,11 @@ import ( ) type DbConfig struct { - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - DbName string `mapstructure:"db-name"` - Address string `mapstructure:"address"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + DbName string `mapstructure:"db-name"` + Address string `mapstructure:"address"` + MaxPaginationLimit int64 `mapstructure:"max-pagination-limit"` } func (cfg *DbConfig) Validate() error { @@ -57,5 +58,9 @@ func (cfg *DbConfig) Validate() error { return fmt.Errorf("port number must be between 1024 and 65535 (inclusive)") } + if cfg.MaxPaginationLimit < 2 { + return fmt.Errorf("max pagination limit must be greater than 1") + } + return nil } diff --git a/internal/db/dbclient.go b/internal/db/dbclient.go index b4b244c..3b985a3 100644 --- a/internal/db/dbclient.go +++ b/internal/db/dbclient.go @@ -3,6 +3,7 @@ package db import ( "context" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" @@ -12,9 +13,15 @@ import ( type Database struct { dbName string client *mongo.Client + cfg *config.DbConfig } -func New(ctx context.Context, cfg config.DbConfig) (*Database, error) { +type DbResultMap[T any] struct { + Data []T `json:"data"` + PaginationToken string `json:"paginationToken"` +} + +func New(ctx context.Context, cfg *config.DbConfig) (*Database, error) { credential := options.Credential{ Username: cfg.Username, Password: cfg.Password, @@ -28,6 +35,7 @@ func New(ctx context.Context, cfg config.DbConfig) (*Database, error) { return &Database{ dbName: cfg.DbName, client: client, + cfg: cfg, }, nil } @@ -42,3 +50,54 @@ func (db *Database) Ping(ctx context.Context) error { func (db *Database) Shutdown(ctx context.Context) error { return db.client.Disconnect(ctx) } + +/* +Builds the result map with a pagination token. +If the result length exceeds the maximum limit, it returns the map with a token. +Otherwise, it returns the map with an empty token. Note that the pagination +limit is the maximum number of results to return. +For example, if the limit is 10, it fetches 11 but returns only 10. +The last result is used to generate the pagination token. +*/ +func toResultMapWithPaginationToken[T any](paginationLimit int64, result []T, paginationKeyBuilder func(T) (string, error)) (*DbResultMap[T], error) { + if len(result) > int(paginationLimit) { + result = result[:paginationLimit] + paginationToken, err := paginationKeyBuilder(result[len(result)-1]) + if err != nil { + return nil, err + } + return &DbResultMap[T]{ + Data: result, + PaginationToken: paginationToken, + }, nil + } + + return &DbResultMap[T]{ + Data: result, + PaginationToken: "", + }, nil +} + +// Finds documents in the collection with pagination in returned results. +func findWithPagination[T any]( + ctx context.Context, client *mongo.Collection, filter bson.M, + options *options.FindOptions, limit int64, + paginationKeyBuilder func(T) (string, error), +) (*DbResultMap[T], error) { + // Always fetch one more than the limit to check if there are more results + // this is used to generate the pagination token + options.SetLimit(limit + 1) + + cursor, err := client.Find(ctx, filter, options) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + var result []T + if err = cursor.All(ctx, &result); err != nil { + return nil, err + } + + return toResultMapWithPaginationToken(limit, result, paginationKeyBuilder) +} diff --git a/internal/db/delegation.go b/internal/db/delegation.go index a02365e..d636fd0 100644 --- a/internal/db/delegation.go +++ b/internal/db/delegation.go @@ -107,30 +107,42 @@ func (db *Database) GetBTCDelegationByStakingTxHash( func (db *Database) GetBTCDelegationsByStates( ctx context.Context, states []types.DelegationState, -) ([]*model.DelegationDocument, error) { - // Convert states to a slice of strings + paginationToken string, +) (*DbResultMap[model.DelegationDocument], error) { + // Convert states to strings stateStrings := make([]string, len(states)) for i, state := range states { stateStrings[i] = state.ToString() } - filter := bson.M{"state": bson.M{"$in": stateStrings}} - opts := options.Find().SetLimit(200) // to prevent large result sets - - cursor, err := db.client.Database(db.dbName). - Collection(model.DelegationsCollection). - Find(ctx, filter, opts) - if err != nil { - return nil, err + // Build filter + filter := bson.M{ + "state": bson.M{"$in": stateStrings}, } - defer cursor.Close(ctx) - var delegations []*model.DelegationDocument - if err := cursor.All(ctx, &delegations); err != nil { - return nil, err + // Setup options + options := options.Find() + options.SetSort(bson.M{"_id": 1}) + + // Decode pagination token if it exists + if paginationToken != "" { + decodedToken, err := model.DecodePaginationToken[model.DelegationScanPagination](paginationToken) + if err != nil { + return nil, &InvalidPaginationTokenError{ + Message: "Invalid pagination token", + } + } + filter["_id"] = bson.M{"$gt": decodedToken.StakingTxHashHex} } - return delegations, nil + return findWithPagination( + ctx, + db.client.Database(db.dbName).Collection(model.DelegationsCollection), + filter, + options, + db.cfg.MaxPaginationLimit, + model.BuildDelegationScanPaginationToken, + ) } func (db *Database) GetBTCDelegationState( diff --git a/internal/db/interface.go b/internal/db/interface.go index 02b78ab..a9b5800 100644 --- a/internal/db/interface.go +++ b/internal/db/interface.go @@ -39,6 +39,10 @@ type DbInterface interface { GetBTCDelegationByStakingTxHash( ctx context.Context, stakingTxHash string, ) (*model.DelegationDocument, error) - GetBTCDelegationsByStates(ctx context.Context, states []types.DelegationState) ([]*model.DelegationDocument, error) + GetBTCDelegationsByStates( + ctx context.Context, + states []types.DelegationState, + paginationToken string, + ) (*DbResultMap[model.DelegationDocument], error) GetBTCDelegationState(ctx context.Context, stakingTxHash string) (*types.DelegationState, error) } diff --git a/internal/db/model/delegation.go b/internal/db/model/delegation.go index c8b2e26..ff6888e 100644 --- a/internal/db/model/delegation.go +++ b/internal/db/model/delegation.go @@ -18,3 +18,18 @@ type DelegationDocument struct { StakingTx *TimelockTransaction `bson:"staking_tx"` // Always exist UnbondingTx *TimelockTransaction `bson:"unbonding_tx,omitempty"` } + +type DelegationScanPagination struct { + StakingTxHashHex string `json:"staking_tx_hash_hex"` +} + +func BuildDelegationScanPaginationToken(d DelegationDocument) (string, error) { + page := &DelegationScanPagination{ + StakingTxHashHex: d.StakingTxHashHex, + } + token, err := GetPaginationToken(page) + if err != nil { + return "", err + } + return token, nil +} diff --git a/internal/db/model/pagination.go b/internal/db/model/pagination.go new file mode 100644 index 0000000..2f93f65 --- /dev/null +++ b/internal/db/model/pagination.go @@ -0,0 +1,27 @@ +package model + +import ( + "encoding/base64" + "encoding/json" +) + +func DecodePaginationToken[T any](token string) (*T, error) { + tokenBytes, err := base64.URLEncoding.DecodeString(token) + if err != nil { + return nil, err + } + var d T + err = json.Unmarshal(tokenBytes, &d) + if err != nil { + return nil, err + } + return &d, nil +} + +func GetPaginationToken[PaginationType any](d PaginationType) (string, error) { + tokenBytes, err := json.Marshal(d) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(tokenBytes), nil +} diff --git a/internal/observability/metrics/metrics.go b/internal/observability/metrics/metrics.go index d711ba8..2c82893 100644 --- a/internal/observability/metrics/metrics.go +++ b/internal/observability/metrics/metrics.go @@ -30,7 +30,7 @@ func (O Outcome) String() string { var ( once sync.Once metricsRouter *chi.Mux - pollDurationHistogram *prometheus.HistogramVec + pollerDurationHistogram *prometheus.HistogramVec btcClientDurationHistogram *prometheus.HistogramVec invalidTransactionsCounter *prometheus.CounterVec failedVerifyingUnbondingTxsCounter prometheus.Counter @@ -74,13 +74,13 @@ func initMetricsRouter(metricsPort int) { // registerMetrics initializes and register the Prometheus metrics. func registerMetrics() { defaultHistogramBucketsSeconds := []float64{0.1, 0.5, 1, 2.5, 5, 10, 30} - pollDurationHistogram = prometheus.NewHistogramVec( + pollerDurationHistogram = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Name: "poll_duration_seconds", Help: "Histogram of poll durations in seconds.", Buckets: defaultHistogramBucketsSeconds, }, - []string{"status"}, + []string{"poller_name", "status"}, ) btcClientDurationHistogram = prometheus.NewHistogramVec( @@ -124,7 +124,7 @@ func registerMetrics() { ) prometheus.MustRegister( - pollDurationHistogram, + pollerDurationHistogram, btcClientDurationHistogram, invalidTransactionsCounter, failedVerifyingUnbondingTxsCounter, @@ -179,3 +179,11 @@ func IncrementFailedVerifyingStakingWithdrawalTxCounter() { func IncrementFailedVerifyingUnbondingWithdrawalTxCounter() { failedVerifyingUnbondingWithdrawalTxsCounter.Inc() } + +func ObservePollerDuration(pollerName string, duration time.Duration, err error) { + status := "success" + if err != nil { + status = "failure" + } + pollerDurationHistogram.WithLabelValues(pollerName, status).Observe(duration.Seconds()) +} diff --git a/internal/services/pollers.go b/internal/services/pollers.go index b1fc286..511a030 100644 --- a/internal/services/pollers.go +++ b/internal/services/pollers.go @@ -10,49 +10,65 @@ import ( ) func (s *Service) processBTCSubscriber(ctx context.Context) *types.Error { - // Get delegations that need BTC notifications - delegations, err := s.db.GetBTCDelegationsByStates(ctx, []types.DelegationState{ - types.Unbonded, - types.UnbondingRequested, - }) - if err != nil { - log.Error().Err(err).Msg("Failed to get delegations for BTC subscription") - return types.NewInternalServiceError(err) - } - - if len(delegations) == 0 { - log.Debug().Msg("No delegations found for BTC subscription") - return nil - } - - // Process each delegation - for _, delegation := range delegations { - if s.trackedSubs.IsSubscribed(delegation.StakingTxHashHex) { - continue - } - - err := s.registerStakingSpendNotification( - delegation.StakingTxHashHex, - delegation.StakingTx.TxHex, - uint32(delegation.StakingTx.OutputIndex), - uint32(delegation.StakingTx.StartHeight), + var ( + pageToken = "" + totalProcessed = 0 + totalSubscribed = 0 + ) + for { + result, err := s.db.GetBTCDelegationsByStates( + ctx, + []types.DelegationState{ + types.Unbonded, + types.UnbondingRequested, + }, + pageToken, ) if err != nil { - log.Error(). - Err(err). - Str("stakingTxHash", delegation.StakingTxHashHex). - Msg("Failed to register staking spend notification") + log.Error().Err(err).Msg("Failed to get delegations for BTC subscription") return types.NewInternalServiceError(err) } - // Add to tracked subscriptions after successful registration - s.trackedSubs.AddSubscription(delegation.StakingTxHashHex) + totalProcessed += len(result.Data) + + // Process batch + for _, delegation := range result.Data { + if s.trackedSubs.IsSubscribed(delegation.StakingTxHashHex) { + continue + } + + if err := s.registerStakingSpendNotification( + delegation.StakingTxHashHex, + delegation.StakingTx.TxHex, + uint32(delegation.StakingTx.OutputIndex), + uint32(delegation.StakingTx.StartHeight), + ); err != nil { + log.Error(). + Err(err). + Str("stakingTxHash", delegation.StakingTxHashHex). + Msg("Failed to register staking spend notification") + return types.NewInternalServiceError(err) + } + + s.trackedSubs.AddSubscription(delegation.StakingTxHashHex) + totalSubscribed++ + + log.Debug(). + Str("stakingTxHash", delegation.StakingTxHashHex). + Msg("Successfully registered BTC notification") + } - log.Debug(). - Str("stakingTxHash", delegation.StakingTxHashHex). - Msg("Successfully registered BTC notification") + pageToken = result.PaginationToken + if pageToken == "" { + break + } } + log.Info(). + Int("total_processed", totalProcessed). + Int("total_subscribed", totalSubscribed). + Msg("BTC subscription processing completed") + return nil } @@ -63,7 +79,8 @@ func (s *Service) processExpiredDelegations(ctx context.Context) *types.Error { return types.NewInternalServiceError(err) } - // Single batch of expired delegations + // Process a single batch of expired delegations without pagination. + // Since we delete each delegation after processing it, pagination is not needed. expiredDelegations, err := s.db.FindExpiredDelegations(ctx, uint64(btcTip)) if err != nil { log.Error().Err(err).Msg("Error finding expired delegations") diff --git a/internal/services/service.go b/internal/services/service.go index 5f55974..c37277c 100644 --- a/internal/services/service.go +++ b/internal/services/service.go @@ -103,9 +103,13 @@ func (s *Service) startExpiryPoller(ctx context.Context) { select { case <-ticker.C: pollingCtx, cancel := context.WithTimeout(ctx, s.cfg.Pollers.ExpiryChecker.Timeout) - if err := s.processExpiredDelegations(pollingCtx); err != nil { + start := time.Now() + err := s.processExpiredDelegations(pollingCtx) + if err != nil { log.Error().Err(err).Msg("Error processing expired delegations") } + duration := time.Since(start) + metrics.ObservePollerDuration("expiry_poller", duration, err) cancel() case <-ctx.Done(): log.Info().Msg("Expiry poller stopped due to context cancellation") @@ -127,9 +131,13 @@ func (s *Service) startBTCSubscriberPoller(ctx context.Context) { select { case <-ticker.C: pollingCtx, cancel := context.WithTimeout(ctx, s.cfg.Pollers.BtcSubscriber.Timeout) - if err := s.processBTCSubscriber(pollingCtx); err != nil { + start := time.Now() + err := s.processBTCSubscriber(pollingCtx) + if err != nil { log.Error().Err(err).Msg("Error processing BTC subscriptions") } + duration := time.Since(start) + metrics.ObservePollerDuration("btc_subscriber_poller", duration, err) cancel() case <-ctx.Done(): log.Info().Msg("BTC subscriber poller stopped due to context cancellation") diff --git a/tests/mocks/mock_db_client.go b/tests/mocks/mock_db_client.go index 4455396..0318805 100644 --- a/tests/mocks/mock_db_client.go +++ b/tests/mocks/mock_db_client.go @@ -5,6 +5,7 @@ package mocks import ( context "context" + db "github.com/babylonlabs-io/staking-expiry-checker/internal/db" mock "github.com/stretchr/testify/mock" model "github.com/babylonlabs-io/staking-expiry-checker/internal/db/model" @@ -127,29 +128,29 @@ func (_m *DbInterface) GetBTCDelegationState(ctx context.Context, stakingTxHash return r0, r1 } -// GetBTCDelegationsByStates provides a mock function with given fields: ctx, states -func (_m *DbInterface) GetBTCDelegationsByStates(ctx context.Context, states []types.DelegationState) ([]*model.DelegationDocument, error) { - ret := _m.Called(ctx, states) +// GetBTCDelegationsByStates provides a mock function with given fields: ctx, states, paginationToken +func (_m *DbInterface) GetBTCDelegationsByStates(ctx context.Context, states []types.DelegationState, paginationToken string) (*db.DbResultMap[model.DelegationDocument], error) { + ret := _m.Called(ctx, states, paginationToken) if len(ret) == 0 { panic("no return value specified for GetBTCDelegationsByStates") } - var r0 []*model.DelegationDocument + var r0 *db.DbResultMap[model.DelegationDocument] var r1 error - if rf, ok := ret.Get(0).(func(context.Context, []types.DelegationState) ([]*model.DelegationDocument, error)); ok { - return rf(ctx, states) + if rf, ok := ret.Get(0).(func(context.Context, []types.DelegationState, string) (*db.DbResultMap[model.DelegationDocument], error)); ok { + return rf(ctx, states, paginationToken) } - if rf, ok := ret.Get(0).(func(context.Context, []types.DelegationState) []*model.DelegationDocument); ok { - r0 = rf(ctx, states) + if rf, ok := ret.Get(0).(func(context.Context, []types.DelegationState, string) *db.DbResultMap[model.DelegationDocument]); ok { + r0 = rf(ctx, states, paginationToken) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*model.DelegationDocument) + r0 = ret.Get(0).(*db.DbResultMap[model.DelegationDocument]) } } - if rf, ok := ret.Get(1).(func(context.Context, []types.DelegationState) error); ok { - r1 = rf(ctx, states) + if rf, ok := ret.Get(1).(func(context.Context, []types.DelegationState, string) error); ok { + r1 = rf(ctx, states, paginationToken) } else { r1 = ret.Error(1) }