diff --git a/server/grpc_service.go b/server/grpc_service.go index 55639474ad0..2b37552b9c4 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -277,14 +277,8 @@ func (s *GrpcServer) GetClusterInfo(context.Context, *pdpb.GetClusterInfoRequest func (s *GrpcServer) GetMinTS( ctx context.Context, request *pdpb.GetMinTSRequest, ) (*pdpb.GetMinTSResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetMinTS(ctx, request) @@ -434,14 +428,8 @@ func (s *GrpcServer) getMinTSFromSingleServer( // GetMembers implements gRPC PDServer. func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb.GetMembersResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } // Here we purposely do not check the cluster ID because the client does not know the correct cluster ID // at startup and needs to get the cluster ID with the first request (i.e. GetMembers). @@ -484,14 +472,8 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb // Tso implements gRPC PDServer. func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return err } if s.IsServiceIndependent(constant.TSOServiceName) { return s.forwardTSO(stream) @@ -602,14 +584,8 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { // Bootstrap implements gRPC PDServer. func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapRequest) (*pdpb.BootstrapResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).Bootstrap(ctx, request) @@ -644,14 +620,8 @@ func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapReque // IsBootstrapped implements gRPC PDServer. func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstrappedRequest) (*pdpb.IsBootstrappedResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).IsBootstrapped(ctx, request) @@ -671,14 +641,8 @@ func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstr // AllocID implements gRPC PDServer. func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) (*pdpb.AllocIDResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).AllocID(ctx, request) @@ -705,14 +669,8 @@ func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) // IsSnapshotRecovering implements gRPC PDServer. func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapshotRecoveringRequest) (*pdpb.IsSnapshotRecoveringResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } if s.IsClosed() { return nil, errs.ErrNotStarted @@ -732,14 +690,8 @@ func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapsho // GetStore implements gRPC PDServer. func (s *GrpcServer) GetStore(ctx context.Context, request *pdpb.GetStoreRequest) (*pdpb.GetStoreResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetStore(ctx, request) @@ -786,14 +738,8 @@ func checkStore(rc *cluster.RaftCluster, storeID uint64) *pdpb.Error { // PutStore implements gRPC PDServer. func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest) (*pdpb.PutStoreResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).PutStore(ctx, request) @@ -841,14 +787,8 @@ func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest // GetAllStores implements gRPC PDServer. func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStoresRequest) (*pdpb.GetAllStoresResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetAllStores(ctx, request) @@ -884,14 +824,8 @@ func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStore // StoreHeartbeat implements gRPC PDServer. func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHeartbeatRequest) (*pdpb.StoreHeartbeatResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).StoreHeartbeat(ctx, request) @@ -1063,14 +997,8 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { cancel() } }() - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return err } for { request, err := server.recv() @@ -1179,14 +1107,8 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error cancel() } }() - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return err } for { request, err := server.Recv() @@ -1387,14 +1309,8 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque failpoint.Inject("rateLimit", func() { failpoint.Return(nil, errs.ErrGRPCRateLimitExceeded(errs.ErrRateLimitExceeded)) }) - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetRegion(ctx, request) @@ -1449,14 +1365,8 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque // GetPrevRegion implements gRPC PDServer func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetPrevRegion(ctx, request) @@ -1506,14 +1416,8 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR // GetRegionByID implements gRPC PDServer. func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionByIDRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetRegionByID(ctx, request) @@ -1566,14 +1470,8 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB // Deprecated: use BatchScanRegions instead. // ScanRegions implements gRPC PDServer. func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsRequest) (*pdpb.ScanRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ScanRegions(ctx, request) @@ -1622,14 +1520,8 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR // BatchScanRegions implements gRPC PDServer. func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchScanRegionsRequest) (*pdpb.BatchScanRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).BatchScanRegions(ctx, request) @@ -1712,14 +1604,8 @@ func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchSc // AskSplit implements gRPC PDServer. func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest) (*pdpb.AskSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).AskSplit(ctx, request) @@ -1756,14 +1642,8 @@ func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest // AskBatchSplit implements gRPC PDServer. func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } rc := s.GetRaftCluster() @@ -1830,14 +1710,8 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp // ReportSplit implements gRPC PDServer. func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitRequest) (*pdpb.ReportSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportSplit(ctx, request) @@ -1866,14 +1740,8 @@ func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitR // ReportBatchSplit implements gRPC PDServer. func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportBatchSplitRequest) (*pdpb.ReportBatchSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportBatchSplit(ctx, request) @@ -1903,14 +1771,8 @@ func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportB // GetClusterConfig implements gRPC PDServer. func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClusterConfigRequest) (*pdpb.GetClusterConfigResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetClusterConfig(ctx, request) @@ -1933,14 +1795,8 @@ func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClus // PutClusterConfig implements gRPC PDServer. func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClusterConfigRequest) (*pdpb.PutClusterConfigResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).PutClusterConfig(ctx, request) @@ -1972,14 +1828,8 @@ func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClus // ScatterRegion implements gRPC PDServer. func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterRegionRequest) (*pdpb.ScatterRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } rc := s.GetRaftCluster() @@ -2084,14 +1934,8 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg // GetGCSafePoint implements gRPC PDServer. func (s *GrpcServer) GetGCSafePoint(ctx context.Context, request *pdpb.GetGCSafePointRequest) (*pdpb.GetGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetGCSafePoint(ctx, request) @@ -2123,14 +1967,8 @@ func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { if s.IsClosed() || s.cluster == nil { return errs.ErrNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return err } ctx := s.cluster.Context() if ctx == nil { @@ -2141,14 +1979,8 @@ func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { // UpdateGCSafePoint implements gRPC PDServer. func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.UpdateGCSafePointRequest) (*pdpb.UpdateGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).UpdateGCSafePoint(ctx, request) @@ -2188,14 +2020,8 @@ func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.Update // UpdateServiceGCSafePoint update the safepoint for specific service func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb.UpdateServiceGCSafePointRequest) (*pdpb.UpdateServiceGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).UpdateServiceGCSafePoint(ctx, request) @@ -2242,14 +2068,8 @@ func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb // GetOperator gets information about the operator belonging to the specify region. func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorRequest) (*pdpb.GetOperatorResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } rc := s.GetRaftCluster() @@ -2454,14 +2274,8 @@ func (*GrpcServer) SyncMaxTS(_ context.Context, _ *pdpb.SyncMaxTSRequest) (*pdpb // SplitRegions split regions by the given split keys func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegionsRequest) (*pdpb.SplitRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } rc := s.GetRaftCluster() @@ -2518,14 +2332,8 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion // Only regions which split successfully will be scattered. // scatterFinishedPercentage indicates the percentage of successfully split regions that are scattered. func (s *GrpcServer) SplitAndScatterRegions(ctx context.Context, request *pdpb.SplitAndScatterRegionsRequest) (*pdpb.SplitAndScatterRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).SplitAndScatterRegions(ctx, request) @@ -2589,14 +2397,8 @@ func (s *GrpcServer) StoreGlobalConfig(_ context.Context, request *pdpb.StoreGlo if s.client == nil { return nil, errs.ErrEtcdNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } configPath := request.GetConfigPath() if configPath == "" { @@ -2635,14 +2437,8 @@ func (s *GrpcServer) LoadGlobalConfig(ctx context.Context, request *pdpb.LoadGlo if s.client == nil { return nil, errs.ErrEtcdNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } configPath := request.GetConfigPath() if configPath == "" { @@ -2683,14 +2479,8 @@ func (s *GrpcServer) WatchGlobalConfig(req *pdpb.WatchGlobalConfigRequest, serve if s.client == nil { return errs.ErrEtcdNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return err } ctx, cancel := context.WithCancel(server.Context()) defer cancel() @@ -2780,14 +2570,8 @@ func (s *GrpcServer) handleDamagedStore(stats *pdpb.StoreStats) { // ReportMinResolvedTS implements gRPC PDServer. func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.ReportMinResolvedTsRequest) (*pdpb.ReportMinResolvedTsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportMinResolvedTS(ctx, request) @@ -2818,14 +2602,8 @@ func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.Repo // SetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.SetExternalTimestampRequest) (*pdpb.SetExternalTimestampResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).SetExternalTimestamp(ctx, request) @@ -2854,14 +2632,8 @@ func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.Set // GetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.GetExternalTimestampRequest) (*pdpb.GetExternalTimestampResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetExternalTimestamp(ctx, request) @@ -2879,8 +2651,21 @@ func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.Get }, nil } -func currentFunction() string { - counter, _, _, _ := runtime.Caller(1) +func getCaller(skip int) string { + counter, _, _, _ := runtime.Caller(skip) s := strings.Split(runtime.FuncForPC(counter).Name(), ".") return s[len(s)-1] } + +func (s *GrpcServer) rateLimitCheck() error { + if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + fName := getCaller(2) + limiter := s.GetGRPCRateLimiter() + if done, err := limiter.Allow(fName); err == nil { + defer done() + } else { + return errs.ErrGRPCRateLimitExceeded(err) + } + } + return nil +}