diff --git a/server/grpc_service.go b/server/grpc_service.go index 55639474ad0..649d02a37b4 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -44,6 +44,7 @@ import ( "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/mcs/utils/constant" + "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/storage/endpoint" "github.com/tikv/pd/pkg/storage/kv" "github.com/tikv/pd/pkg/utils/grpcutil" @@ -277,14 +278,12 @@ func (s *GrpcServer) GetClusterInfo(context.Context, *pdpb.GetClusterInfoRequest func (s *GrpcServer) GetMinTS( ctx context.Context, request *pdpb.GetMinTSRequest, ) (*pdpb.GetMinTSResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetMinTS(ctx, request) @@ -295,10 +294,7 @@ func (s *GrpcServer) GetMinTS( return rsp.(*pdpb.GetMinTSResponse), nil } - var ( - minTS *pdpb.Timestamp - err error - ) + var minTS *pdpb.Timestamp if s.IsServiceIndependent(constant.TSOServiceName) { minTS, err = s.GetMinTSFromTSOService() } else { @@ -434,14 +430,12 @@ func (s *GrpcServer) getMinTSFromSingleServer( // GetMembers implements gRPC PDServer. func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb.GetMembersResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } // Here we purposely do not check the cluster ID because the client does not know the correct cluster ID // at startup and needs to get the cluster ID with the first request (i.e. GetMembers). @@ -484,14 +478,12 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb // Tso implements gRPC PDServer. func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return err + } + if done != nil { + defer done() } if s.IsServiceIndependent(constant.TSOServiceName) { return s.forwardTSO(stream) @@ -602,14 +594,12 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { // Bootstrap implements gRPC PDServer. func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapRequest) (*pdpb.BootstrapResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).Bootstrap(ctx, request) @@ -644,14 +634,12 @@ func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapReque // IsBootstrapped implements gRPC PDServer. func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstrappedRequest) (*pdpb.IsBootstrappedResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).IsBootstrapped(ctx, request) @@ -671,14 +659,12 @@ func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstr // AllocID implements gRPC PDServer. func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) (*pdpb.AllocIDResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).AllocID(ctx, request) @@ -705,15 +691,14 @@ func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) // IsSnapshotRecovering implements gRPC PDServer. func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapshotRecoveringRequest) (*pdpb.IsSnapshotRecoveringResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err } + if done != nil { + defer done() + } + if s.IsClosed() { return nil, errs.ErrNotStarted } @@ -732,14 +717,12 @@ func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapsho // GetStore implements gRPC PDServer. func (s *GrpcServer) GetStore(ctx context.Context, request *pdpb.GetStoreRequest) (*pdpb.GetStoreResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetStore(ctx, request) @@ -786,14 +769,12 @@ func checkStore(rc *cluster.RaftCluster, storeID uint64) *pdpb.Error { // PutStore implements gRPC PDServer. func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest) (*pdpb.PutStoreResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).PutStore(ctx, request) @@ -841,14 +822,12 @@ func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest // GetAllStores implements gRPC PDServer. func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStoresRequest) (*pdpb.GetAllStoresResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetAllStores(ctx, request) @@ -884,14 +863,12 @@ func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStore // StoreHeartbeat implements gRPC PDServer. func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHeartbeatRequest) (*pdpb.StoreHeartbeatResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).StoreHeartbeat(ctx, request) @@ -1063,14 +1040,12 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { cancel() } }() - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return err + } + if done != nil { + defer done() } for { request, err := server.recv() @@ -1179,14 +1154,12 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error cancel() } }() - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return err + } + if done != nil { + defer done() } for { request, err := server.Recv() @@ -1387,14 +1360,12 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque failpoint.Inject("rateLimit", func() { failpoint.Return(nil, errs.ErrGRPCRateLimitExceeded(errs.ErrRateLimitExceeded)) }) - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetRegion(ctx, request) @@ -1449,14 +1420,12 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque // GetPrevRegion implements gRPC PDServer func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetPrevRegion(ctx, request) @@ -1506,14 +1475,12 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR // GetRegionByID implements gRPC PDServer. func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionByIDRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetRegionByID(ctx, request) @@ -1566,14 +1533,12 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB // Deprecated: use BatchScanRegions instead. // ScanRegions implements gRPC PDServer. func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsRequest) (*pdpb.ScanRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ScanRegions(ctx, request) @@ -1622,14 +1587,12 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR // BatchScanRegions implements gRPC PDServer. func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchScanRegionsRequest) (*pdpb.BatchScanRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).BatchScanRegions(ctx, request) @@ -1712,14 +1675,12 @@ func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchSc // AskSplit implements gRPC PDServer. func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest) (*pdpb.AskSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).AskSplit(ctx, request) @@ -1756,14 +1717,12 @@ func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest // AskBatchSplit implements gRPC PDServer. func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } rc := s.GetRaftCluster() @@ -1830,14 +1789,12 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp // ReportSplit implements gRPC PDServer. func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitRequest) (*pdpb.ReportSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportSplit(ctx, request) @@ -1852,7 +1809,7 @@ func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitR if rc == nil { return &pdpb.ReportSplitResponse{Header: notBootstrappedHeader()}, nil } - _, err := rc.HandleReportSplit(request) + _, err = rc.HandleReportSplit(request) if err != nil { return &pdpb.ReportSplitResponse{ Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), @@ -1866,14 +1823,12 @@ func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitR // ReportBatchSplit implements gRPC PDServer. func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportBatchSplitRequest) (*pdpb.ReportBatchSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportBatchSplit(ctx, request) @@ -1888,7 +1843,7 @@ func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportB if rc == nil { return &pdpb.ReportBatchSplitResponse{Header: notBootstrappedHeader()}, nil } - _, err := rc.HandleBatchReportSplit(request) + _, err = rc.HandleBatchReportSplit(request) if err != nil { return &pdpb.ReportBatchSplitResponse{ Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, @@ -1903,14 +1858,12 @@ func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportB // GetClusterConfig implements gRPC PDServer. func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClusterConfigRequest) (*pdpb.GetClusterConfigResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetClusterConfig(ctx, request) @@ -1933,14 +1886,12 @@ func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClus // PutClusterConfig implements gRPC PDServer. func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClusterConfigRequest) (*pdpb.PutClusterConfigResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).PutClusterConfig(ctx, request) @@ -1972,14 +1923,12 @@ func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClus // ScatterRegion implements gRPC PDServer. func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterRegionRequest) (*pdpb.ScatterRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } rc := s.GetRaftCluster() @@ -2084,14 +2033,12 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg // GetGCSafePoint implements gRPC PDServer. func (s *GrpcServer) GetGCSafePoint(ctx context.Context, request *pdpb.GetGCSafePointRequest) (*pdpb.GetGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetGCSafePoint(ctx, request) @@ -2123,14 +2070,12 @@ func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { if s.IsClosed() || s.cluster == nil { return errs.ErrNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return err + } + if done != nil { + defer done() } ctx := s.cluster.Context() if ctx == nil { @@ -2141,14 +2086,12 @@ func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { // UpdateGCSafePoint implements gRPC PDServer. func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.UpdateGCSafePointRequest) (*pdpb.UpdateGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).UpdateGCSafePoint(ctx, request) @@ -2188,14 +2131,12 @@ func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.Update // UpdateServiceGCSafePoint update the safepoint for specific service func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb.UpdateServiceGCSafePointRequest) (*pdpb.UpdateServiceGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).UpdateServiceGCSafePoint(ctx, request) @@ -2242,14 +2183,12 @@ func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb // GetOperator gets information about the operator belonging to the specify region. func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorRequest) (*pdpb.GetOperatorResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } rc := s.GetRaftCluster() @@ -2454,14 +2393,12 @@ func (*GrpcServer) SyncMaxTS(_ context.Context, _ *pdpb.SyncMaxTSRequest) (*pdpb // SplitRegions split regions by the given split keys func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegionsRequest) (*pdpb.SplitRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } rc := s.GetRaftCluster() @@ -2518,14 +2455,12 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion // Only regions which split successfully will be scattered. // scatterFinishedPercentage indicates the percentage of successfully split regions that are scattered. func (s *GrpcServer) SplitAndScatterRegions(ctx context.Context, request *pdpb.SplitAndScatterRegionsRequest) (*pdpb.SplitAndScatterRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).SplitAndScatterRegions(ctx, request) @@ -2589,14 +2524,12 @@ func (s *GrpcServer) StoreGlobalConfig(_ context.Context, request *pdpb.StoreGlo if s.client == nil { return nil, errs.ErrEtcdNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } configPath := request.GetConfigPath() if configPath == "" { @@ -2635,14 +2568,12 @@ func (s *GrpcServer) LoadGlobalConfig(ctx context.Context, request *pdpb.LoadGlo if s.client == nil { return nil, errs.ErrEtcdNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } configPath := request.GetConfigPath() if configPath == "" { @@ -2683,14 +2614,12 @@ func (s *GrpcServer) WatchGlobalConfig(req *pdpb.WatchGlobalConfigRequest, serve if s.client == nil { return errs.ErrEtcdNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return err + } + if done != nil { + defer done() } ctx, cancel := context.WithCancel(server.Context()) defer cancel() @@ -2780,14 +2709,12 @@ func (s *GrpcServer) handleDamagedStore(stats *pdpb.StoreStats) { // ReportMinResolvedTS implements gRPC PDServer. func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.ReportMinResolvedTsRequest) (*pdpb.ReportMinResolvedTsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportMinResolvedTS(ctx, request) @@ -2818,14 +2745,12 @@ func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.Repo // SetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.SetExternalTimestampRequest) (*pdpb.SetExternalTimestampResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).SetExternalTimestamp(ctx, request) @@ -2854,14 +2779,12 @@ func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.Set // GetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.GetExternalTimestampRequest) (*pdpb.GetExternalTimestampResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + done, err := s.rateLimitCheck() + if err != nil { + return nil, err + } + if done != nil { + defer done() } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetExternalTimestamp(ctx, request) @@ -2879,8 +2802,21 @@ func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.Get }, nil } -func currentFunction() string { - counter, _, _, _ := runtime.Caller(1) +func getCaller(skip int) string { + counter, _, _, _ := runtime.Caller(skip) s := strings.Split(runtime.FuncForPC(counter).Name(), ".") return s[len(s)-1] } + +func (s *GrpcServer) rateLimitCheck() (done ratelimit.DoneFunc, err error) { + if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + fName := getCaller(2) + limiter := s.GetGRPCRateLimiter() + if done, err = limiter.Allow(fName); err == nil { + return + } + err = errs.ErrGRPCRateLimitExceeded(err) + return + } + return +}