From b7e9ad553166b30997cba0fa74fd8160dce930df Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Tue, 14 Jan 2025 15:48:05 +0800 Subject: [PATCH 1/2] use rateLimitCheck to reduce code Signed-off-by: Ryan Leung --- server/grpc_service.go | 397 ++++++++++------------------------------- 1 file changed, 91 insertions(+), 306 deletions(-) diff --git a/server/grpc_service.go b/server/grpc_service.go index 55639474ad0..2b37552b9c4 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -277,14 +277,8 @@ func (s *GrpcServer) GetClusterInfo(context.Context, *pdpb.GetClusterInfoRequest func (s *GrpcServer) GetMinTS( ctx context.Context, request *pdpb.GetMinTSRequest, ) (*pdpb.GetMinTSResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetMinTS(ctx, request) @@ -434,14 +428,8 @@ func (s *GrpcServer) getMinTSFromSingleServer( // GetMembers implements gRPC PDServer. func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb.GetMembersResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } // Here we purposely do not check the cluster ID because the client does not know the correct cluster ID // at startup and needs to get the cluster ID with the first request (i.e. GetMembers). @@ -484,14 +472,8 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb // Tso implements gRPC PDServer. func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return err } if s.IsServiceIndependent(constant.TSOServiceName) { return s.forwardTSO(stream) @@ -602,14 +584,8 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { // Bootstrap implements gRPC PDServer. func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapRequest) (*pdpb.BootstrapResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).Bootstrap(ctx, request) @@ -644,14 +620,8 @@ func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapReque // IsBootstrapped implements gRPC PDServer. func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstrappedRequest) (*pdpb.IsBootstrappedResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).IsBootstrapped(ctx, request) @@ -671,14 +641,8 @@ func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstr // AllocID implements gRPC PDServer. func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) (*pdpb.AllocIDResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).AllocID(ctx, request) @@ -705,14 +669,8 @@ func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) // IsSnapshotRecovering implements gRPC PDServer. func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapshotRecoveringRequest) (*pdpb.IsSnapshotRecoveringResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } if s.IsClosed() { return nil, errs.ErrNotStarted @@ -732,14 +690,8 @@ func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapsho // GetStore implements gRPC PDServer. func (s *GrpcServer) GetStore(ctx context.Context, request *pdpb.GetStoreRequest) (*pdpb.GetStoreResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetStore(ctx, request) @@ -786,14 +738,8 @@ func checkStore(rc *cluster.RaftCluster, storeID uint64) *pdpb.Error { // PutStore implements gRPC PDServer. func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest) (*pdpb.PutStoreResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).PutStore(ctx, request) @@ -841,14 +787,8 @@ func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest // GetAllStores implements gRPC PDServer. func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStoresRequest) (*pdpb.GetAllStoresResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetAllStores(ctx, request) @@ -884,14 +824,8 @@ func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStore // StoreHeartbeat implements gRPC PDServer. func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHeartbeatRequest) (*pdpb.StoreHeartbeatResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).StoreHeartbeat(ctx, request) @@ -1063,14 +997,8 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { cancel() } }() - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return err } for { request, err := server.recv() @@ -1179,14 +1107,8 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error cancel() } }() - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return err } for { request, err := server.Recv() @@ -1387,14 +1309,8 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque failpoint.Inject("rateLimit", func() { failpoint.Return(nil, errs.ErrGRPCRateLimitExceeded(errs.ErrRateLimitExceeded)) }) - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetRegion(ctx, request) @@ -1449,14 +1365,8 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque // GetPrevRegion implements gRPC PDServer func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetPrevRegion(ctx, request) @@ -1506,14 +1416,8 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR // GetRegionByID implements gRPC PDServer. func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionByIDRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetRegionByID(ctx, request) @@ -1566,14 +1470,8 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB // Deprecated: use BatchScanRegions instead. // ScanRegions implements gRPC PDServer. func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsRequest) (*pdpb.ScanRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ScanRegions(ctx, request) @@ -1622,14 +1520,8 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR // BatchScanRegions implements gRPC PDServer. func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchScanRegionsRequest) (*pdpb.BatchScanRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).BatchScanRegions(ctx, request) @@ -1712,14 +1604,8 @@ func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchSc // AskSplit implements gRPC PDServer. func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest) (*pdpb.AskSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).AskSplit(ctx, request) @@ -1756,14 +1642,8 @@ func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest // AskBatchSplit implements gRPC PDServer. func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } rc := s.GetRaftCluster() @@ -1830,14 +1710,8 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp // ReportSplit implements gRPC PDServer. func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitRequest) (*pdpb.ReportSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportSplit(ctx, request) @@ -1866,14 +1740,8 @@ func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitR // ReportBatchSplit implements gRPC PDServer. func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportBatchSplitRequest) (*pdpb.ReportBatchSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportBatchSplit(ctx, request) @@ -1903,14 +1771,8 @@ func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportB // GetClusterConfig implements gRPC PDServer. func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClusterConfigRequest) (*pdpb.GetClusterConfigResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetClusterConfig(ctx, request) @@ -1933,14 +1795,8 @@ func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClus // PutClusterConfig implements gRPC PDServer. func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClusterConfigRequest) (*pdpb.PutClusterConfigResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).PutClusterConfig(ctx, request) @@ -1972,14 +1828,8 @@ func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClus // ScatterRegion implements gRPC PDServer. func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterRegionRequest) (*pdpb.ScatterRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } rc := s.GetRaftCluster() @@ -2084,14 +1934,8 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg // GetGCSafePoint implements gRPC PDServer. func (s *GrpcServer) GetGCSafePoint(ctx context.Context, request *pdpb.GetGCSafePointRequest) (*pdpb.GetGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetGCSafePoint(ctx, request) @@ -2123,14 +1967,8 @@ func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { if s.IsClosed() || s.cluster == nil { return errs.ErrNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return err } ctx := s.cluster.Context() if ctx == nil { @@ -2141,14 +1979,8 @@ func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { // UpdateGCSafePoint implements gRPC PDServer. func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.UpdateGCSafePointRequest) (*pdpb.UpdateGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).UpdateGCSafePoint(ctx, request) @@ -2188,14 +2020,8 @@ func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.Update // UpdateServiceGCSafePoint update the safepoint for specific service func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb.UpdateServiceGCSafePointRequest) (*pdpb.UpdateServiceGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).UpdateServiceGCSafePoint(ctx, request) @@ -2242,14 +2068,8 @@ func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb // GetOperator gets information about the operator belonging to the specify region. func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorRequest) (*pdpb.GetOperatorResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } rc := s.GetRaftCluster() @@ -2454,14 +2274,8 @@ func (*GrpcServer) SyncMaxTS(_ context.Context, _ *pdpb.SyncMaxTSRequest) (*pdpb // SplitRegions split regions by the given split keys func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegionsRequest) (*pdpb.SplitRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } rc := s.GetRaftCluster() @@ -2518,14 +2332,8 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion // Only regions which split successfully will be scattered. // scatterFinishedPercentage indicates the percentage of successfully split regions that are scattered. func (s *GrpcServer) SplitAndScatterRegions(ctx context.Context, request *pdpb.SplitAndScatterRegionsRequest) (*pdpb.SplitAndScatterRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).SplitAndScatterRegions(ctx, request) @@ -2589,14 +2397,8 @@ func (s *GrpcServer) StoreGlobalConfig(_ context.Context, request *pdpb.StoreGlo if s.client == nil { return nil, errs.ErrEtcdNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } configPath := request.GetConfigPath() if configPath == "" { @@ -2635,14 +2437,8 @@ func (s *GrpcServer) LoadGlobalConfig(ctx context.Context, request *pdpb.LoadGlo if s.client == nil { return nil, errs.ErrEtcdNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } configPath := request.GetConfigPath() if configPath == "" { @@ -2683,14 +2479,8 @@ func (s *GrpcServer) WatchGlobalConfig(req *pdpb.WatchGlobalConfigRequest, serve if s.client == nil { return errs.ErrEtcdNotStarted } - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return err } ctx, cancel := context.WithCancel(server.Context()) defer cancel() @@ -2780,14 +2570,8 @@ func (s *GrpcServer) handleDamagedStore(stats *pdpb.StoreStats) { // ReportMinResolvedTS implements gRPC PDServer. func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.ReportMinResolvedTsRequest) (*pdpb.ReportMinResolvedTsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportMinResolvedTS(ctx, request) @@ -2818,14 +2602,8 @@ func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.Repo // SetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.SetExternalTimestampRequest) (*pdpb.SetExternalTimestampResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).SetExternalTimestamp(ctx, request) @@ -2854,14 +2632,8 @@ func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.Set // GetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.GetExternalTimestampRequest) (*pdpb.GetExternalTimestampResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return nil, errs.ErrGRPCRateLimitExceeded(err) - } + if err := s.rateLimitCheck(); err != nil { + return nil, err } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetExternalTimestamp(ctx, request) @@ -2879,8 +2651,21 @@ func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.Get }, nil } -func currentFunction() string { - counter, _, _, _ := runtime.Caller(1) +func getCaller(skip int) string { + counter, _, _, _ := runtime.Caller(skip) s := strings.Split(runtime.FuncForPC(counter).Name(), ".") return s[len(s)-1] } + +func (s *GrpcServer) rateLimitCheck() error { + if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + fName := getCaller(2) + limiter := s.GetGRPCRateLimiter() + if done, err := limiter.Allow(fName); err == nil { + defer done() + } else { + return errs.ErrGRPCRateLimitExceeded(err) + } + } + return nil +} From 487aa630770bedd624de3bb7f50b21518773b2ea Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Tue, 14 Jan 2025 18:34:09 +0800 Subject: [PATCH 2/2] fix done Signed-off-by: Ryan Leung --- server/grpc_service.go | 251 +++++++++++++++++++++++++++++++++-------- 1 file changed, 201 insertions(+), 50 deletions(-) diff --git a/server/grpc_service.go b/server/grpc_service.go index 2b37552b9c4..649d02a37b4 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -44,6 +44,7 @@ import ( "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/mcs/utils/constant" + "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/storage/endpoint" "github.com/tikv/pd/pkg/storage/kv" "github.com/tikv/pd/pkg/utils/grpcutil" @@ -277,9 +278,13 @@ func (s *GrpcServer) GetClusterInfo(context.Context, *pdpb.GetClusterInfoRequest func (s *GrpcServer) GetMinTS( ctx context.Context, request *pdpb.GetMinTSRequest, ) (*pdpb.GetMinTSResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetMinTS(ctx, request) } @@ -289,10 +294,7 @@ func (s *GrpcServer) GetMinTS( return rsp.(*pdpb.GetMinTSResponse), nil } - var ( - minTS *pdpb.Timestamp - err error - ) + var minTS *pdpb.Timestamp if s.IsServiceIndependent(constant.TSOServiceName) { minTS, err = s.GetMinTSFromTSOService() } else { @@ -428,9 +430,13 @@ func (s *GrpcServer) getMinTSFromSingleServer( // GetMembers implements gRPC PDServer. func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb.GetMembersResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } // Here we purposely do not check the cluster ID because the client does not know the correct cluster ID // at startup and needs to get the cluster ID with the first request (i.e. GetMembers). if s.IsClosed() { @@ -472,9 +478,13 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb // Tso implements gRPC PDServer. func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return err } + if done != nil { + defer done() + } if s.IsServiceIndependent(constant.TSOServiceName) { return s.forwardTSO(stream) } @@ -584,9 +594,13 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { // Bootstrap implements gRPC PDServer. func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapRequest) (*pdpb.BootstrapResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).Bootstrap(ctx, request) } @@ -620,9 +634,13 @@ func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapReque // IsBootstrapped implements gRPC PDServer. func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstrappedRequest) (*pdpb.IsBootstrappedResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).IsBootstrapped(ctx, request) } @@ -641,9 +659,13 @@ func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstr // AllocID implements gRPC PDServer. func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) (*pdpb.AllocIDResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).AllocID(ctx, request) } @@ -669,9 +691,14 @@ func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) // IsSnapshotRecovering implements gRPC PDServer. func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapshotRecoveringRequest) (*pdpb.IsSnapshotRecoveringResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } + if s.IsClosed() { return nil, errs.ErrNotStarted } @@ -690,9 +717,13 @@ func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapsho // GetStore implements gRPC PDServer. func (s *GrpcServer) GetStore(ctx context.Context, request *pdpb.GetStoreRequest) (*pdpb.GetStoreResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetStore(ctx, request) } @@ -738,9 +769,13 @@ func checkStore(rc *cluster.RaftCluster, storeID uint64) *pdpb.Error { // PutStore implements gRPC PDServer. func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest) (*pdpb.PutStoreResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).PutStore(ctx, request) } @@ -787,9 +822,13 @@ func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest // GetAllStores implements gRPC PDServer. func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStoresRequest) (*pdpb.GetAllStoresResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetAllStores(ctx, request) } @@ -824,9 +863,13 @@ func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStore // StoreHeartbeat implements gRPC PDServer. func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHeartbeatRequest) (*pdpb.StoreHeartbeatResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).StoreHeartbeat(ctx, request) } @@ -997,9 +1040,13 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { cancel() } }() - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return err } + if done != nil { + defer done() + } for { request, err := server.recv() failpoint.Inject("grpcClientClosed", func() { @@ -1107,9 +1154,13 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error cancel() } }() - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return err } + if done != nil { + defer done() + } for { request, err := server.Recv() if err == io.EOF { @@ -1309,9 +1360,13 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque failpoint.Inject("rateLimit", func() { failpoint.Return(nil, errs.ErrGRPCRateLimitExceeded(errs.ErrRateLimitExceeded)) }) - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetRegion(ctx, request) } @@ -1365,9 +1420,13 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque // GetPrevRegion implements gRPC PDServer func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetPrevRegion(ctx, request) } @@ -1416,9 +1475,13 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR // GetRegionByID implements gRPC PDServer. func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionByIDRequest) (*pdpb.GetRegionResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetRegionByID(ctx, request) } @@ -1470,9 +1533,13 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB // Deprecated: use BatchScanRegions instead. // ScanRegions implements gRPC PDServer. func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsRequest) (*pdpb.ScanRegionsResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ScanRegions(ctx, request) } @@ -1520,9 +1587,13 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR // BatchScanRegions implements gRPC PDServer. func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchScanRegionsRequest) (*pdpb.BatchScanRegionsResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).BatchScanRegions(ctx, request) } @@ -1604,9 +1675,13 @@ func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchSc // AskSplit implements gRPC PDServer. func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest) (*pdpb.AskSplitResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).AskSplit(ctx, request) } @@ -1642,9 +1717,13 @@ func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest // AskBatchSplit implements gRPC PDServer. func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } rc := s.GetRaftCluster() if rc == nil { @@ -1710,9 +1789,13 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp // ReportSplit implements gRPC PDServer. func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitRequest) (*pdpb.ReportSplitResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportSplit(ctx, request) } @@ -1726,7 +1809,7 @@ func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitR if rc == nil { return &pdpb.ReportSplitResponse{Header: notBootstrappedHeader()}, nil } - _, err := rc.HandleReportSplit(request) + _, err = rc.HandleReportSplit(request) if err != nil { return &pdpb.ReportSplitResponse{ Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), @@ -1740,9 +1823,13 @@ func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitR // ReportBatchSplit implements gRPC PDServer. func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportBatchSplitRequest) (*pdpb.ReportBatchSplitResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportBatchSplit(ctx, request) } @@ -1756,7 +1843,7 @@ func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportB if rc == nil { return &pdpb.ReportBatchSplitResponse{Header: notBootstrappedHeader()}, nil } - _, err := rc.HandleBatchReportSplit(request) + _, err = rc.HandleBatchReportSplit(request) if err != nil { return &pdpb.ReportBatchSplitResponse{ Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, @@ -1771,9 +1858,13 @@ func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportB // GetClusterConfig implements gRPC PDServer. func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClusterConfigRequest) (*pdpb.GetClusterConfigResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetClusterConfig(ctx, request) } @@ -1795,9 +1886,13 @@ func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClus // PutClusterConfig implements gRPC PDServer. func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClusterConfigRequest) (*pdpb.PutClusterConfigResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).PutClusterConfig(ctx, request) } @@ -1828,9 +1923,13 @@ func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClus // ScatterRegion implements gRPC PDServer. func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterRegionRequest) (*pdpb.ScatterRegionResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } rc := s.GetRaftCluster() if rc == nil { @@ -1934,9 +2033,13 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg // GetGCSafePoint implements gRPC PDServer. func (s *GrpcServer) GetGCSafePoint(ctx context.Context, request *pdpb.GetGCSafePointRequest) (*pdpb.GetGCSafePointResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetGCSafePoint(ctx, request) } @@ -1967,9 +2070,13 @@ func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { if s.IsClosed() || s.cluster == nil { return errs.ErrNotStarted } - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return err } + if done != nil { + defer done() + } ctx := s.cluster.Context() if ctx == nil { return errs.ErrNotStarted @@ -1979,9 +2086,13 @@ func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { // UpdateGCSafePoint implements gRPC PDServer. func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.UpdateGCSafePointRequest) (*pdpb.UpdateGCSafePointResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).UpdateGCSafePoint(ctx, request) } @@ -2020,9 +2131,13 @@ func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.Update // UpdateServiceGCSafePoint update the safepoint for specific service func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb.UpdateServiceGCSafePointRequest) (*pdpb.UpdateServiceGCSafePointResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).UpdateServiceGCSafePoint(ctx, request) } @@ -2068,9 +2183,13 @@ func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb // GetOperator gets information about the operator belonging to the specify region. func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorRequest) (*pdpb.GetOperatorResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } rc := s.GetRaftCluster() if rc == nil { @@ -2274,9 +2393,13 @@ func (*GrpcServer) SyncMaxTS(_ context.Context, _ *pdpb.SyncMaxTSRequest) (*pdpb // SplitRegions split regions by the given split keys func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegionsRequest) (*pdpb.SplitRegionsResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } rc := s.GetRaftCluster() if rc == nil { @@ -2332,9 +2455,13 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion // Only regions which split successfully will be scattered. // scatterFinishedPercentage indicates the percentage of successfully split regions that are scattered. func (s *GrpcServer) SplitAndScatterRegions(ctx context.Context, request *pdpb.SplitAndScatterRegionsRequest) (*pdpb.SplitAndScatterRegionsResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).SplitAndScatterRegions(ctx, request) } @@ -2397,9 +2524,13 @@ func (s *GrpcServer) StoreGlobalConfig(_ context.Context, request *pdpb.StoreGlo if s.client == nil { return nil, errs.ErrEtcdNotStarted } - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } configPath := request.GetConfigPath() if configPath == "" { configPath = globalConfigPath @@ -2437,9 +2568,13 @@ func (s *GrpcServer) LoadGlobalConfig(ctx context.Context, request *pdpb.LoadGlo if s.client == nil { return nil, errs.ErrEtcdNotStarted } - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } configPath := request.GetConfigPath() if configPath == "" { configPath = globalConfigPath @@ -2479,9 +2614,13 @@ func (s *GrpcServer) WatchGlobalConfig(req *pdpb.WatchGlobalConfigRequest, serve if s.client == nil { return errs.ErrEtcdNotStarted } - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return err } + if done != nil { + defer done() + } ctx, cancel := context.WithCancel(server.Context()) defer cancel() configPath := req.GetConfigPath() @@ -2570,9 +2709,13 @@ func (s *GrpcServer) handleDamagedStore(stats *pdpb.StoreStats) { // ReportMinResolvedTS implements gRPC PDServer. func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.ReportMinResolvedTsRequest) (*pdpb.ReportMinResolvedTsResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).ReportMinResolvedTS(ctx, request) } @@ -2602,9 +2745,13 @@ func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.Repo // SetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.SetExternalTimestampRequest) (*pdpb.SetExternalTimestampResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).SetExternalTimestamp(ctx, request) } @@ -2632,9 +2779,13 @@ func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.Set // GetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.GetExternalTimestampRequest) (*pdpb.GetExternalTimestampResponse, error) { - if err := s.rateLimitCheck(); err != nil { + done, err := s.rateLimitCheck() + if err != nil { return nil, err } + if done != nil { + defer done() + } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { return pdpb.NewPDClient(client).GetExternalTimestamp(ctx, request) } @@ -2657,15 +2808,15 @@ func getCaller(skip int) string { return s[len(s)-1] } -func (s *GrpcServer) rateLimitCheck() error { +func (s *GrpcServer) rateLimitCheck() (done ratelimit.DoneFunc, err error) { if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := getCaller(2) limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { - return errs.ErrGRPCRateLimitExceeded(err) + if done, err = limiter.Allow(fName); err == nil { + return } + err = errs.ErrGRPCRateLimitExceeded(err) + return } - return nil + return }