From 8dcd49720cd9999119d27212220bc0b03f82a75e Mon Sep 17 00:00:00 2001
From: Ryan Leung
Date: Tue, 14 Nov 2023 11:29:14 +0800
Subject: [PATCH] *: Improve region forward (#7305)

ref tikv/pd#5839

Signed-off-by: Ryan Leung
Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
---
 server/forward.go      | 504 +++++++++++++++++++++++
 server/grpc_service.go | 891 ++++++++++++-----------------------------
 2 files changed, 750 insertions(+), 645 deletions(-)
 create mode 100644 server/forward.go

diff --git a/server/forward.go b/server/forward.go
new file mode 100644
index 00000000000..e765d442539
--- /dev/null
+++ b/server/forward.go
@@ -0,0 +1,504 @@
+// Copyright 2023 TiKV Project Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package server
+
+import (
+	"context"
+	"io"
+	"strings"
+	"time"
+
+	"github.com/pingcap/errors"
+	"github.com/pingcap/failpoint"
+	"github.com/pingcap/kvproto/pkg/pdpb"
+	"github.com/pingcap/kvproto/pkg/schedulingpb"
+	"github.com/pingcap/kvproto/pkg/tsopb"
+	"github.com/pingcap/log"
+	"github.com/tikv/pd/pkg/errs"
+	"github.com/tikv/pd/pkg/mcs/utils"
+	"github.com/tikv/pd/pkg/tso"
+	"github.com/tikv/pd/pkg/utils/grpcutil"
+	"github.com/tikv/pd/pkg/utils/logutil"
+	"github.com/tikv/pd/pkg/utils/tsoutil"
+	"go.uber.org/zap"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+)
+
+func (s *GrpcServer) forwardTSORequest(
+	ctx context.Context,
+	request *pdpb.TsoRequest,
+	forwardStream tsopb.TSO_TsoClient) (*tsopb.TsoResponse, error) {
+	tsopbReq := &tsopb.TsoRequest{
+		Header: &tsopb.RequestHeader{
+			ClusterId:       request.GetHeader().GetClusterId(),
+			SenderId:        request.GetHeader().GetSenderId(),
+			KeyspaceId:      utils.DefaultKeyspaceID,
+			KeyspaceGroupId: utils.DefaultKeyspaceGroupID,
+		},
+		Count:      request.GetCount(),
+		DcLocation: request.GetDcLocation(),
+	}
+
+	failpoint.Inject("tsoProxySendToTSOTimeout", func() {
+		// block until watchDeadline routine cancels the context.
+		<-ctx.Done()
+	})
+
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	default:
+	}
+
+	if err := forwardStream.Send(tsopbReq); err != nil {
+		return nil, err
+	}
+
+	failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() {
+		// block until watchDeadline routine cancels the context.
+		<-ctx.Done()
+	})
+
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	default:
+	}
+
+	return forwardStream.Recv()
+}
+
+// forwardTSO forwards the TSO requests to the TSO service.
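+// It wraps the client stream in a tsoServer, enforces the configured limit on
+// concurrent TSO proxy streams, and re-creates the forward stream whenever the
+// TSO primary address changes.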
+func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { + var ( + server = &tsoServer{stream: stream} + forwardStream tsopb.TSO_TsoClient + forwardCtx context.Context + cancelForward context.CancelFunc + lastForwardedHost string + ) + defer func() { + s.concurrentTSOProxyStreamings.Add(-1) + if cancelForward != nil { + cancelForward() + } + }() + + maxConcurrentTSOProxyStreamings := int32(s.GetMaxConcurrentTSOProxyStreamings()) + if maxConcurrentTSOProxyStreamings >= 0 { + if newCount := s.concurrentTSOProxyStreamings.Add(1); newCount > maxConcurrentTSOProxyStreamings { + return errors.WithStack(ErrMaxCountTSOProxyRoutinesExceeded) + } + } + + tsDeadlineCh := make(chan *tsoutil.TSDeadline, 1) + go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh) + + for { + select { + case <-s.ctx.Done(): + return errors.WithStack(s.ctx.Err()) + case <-stream.Context().Done(): + return stream.Context().Err() + default: + } + + request, err := server.Recv(s.GetTSOProxyRecvFromClientTimeout()) + if err == io.EOF { + return nil + } + if err != nil { + return errors.WithStack(err) + } + if request.GetCount() == 0 { + err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") + return status.Errorf(codes.Unknown, err.Error()) + } + + forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.TSOServiceName) + if !ok || len(forwardedHost) == 0 { + return errors.WithStack(ErrNotFoundTSOAddr) + } + if forwardStream == nil || lastForwardedHost != forwardedHost { + if cancelForward != nil { + cancelForward() + } + + clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) + if err != nil { + return errors.WithStack(err) + } + forwardStream, forwardCtx, cancelForward, err = s.createTSOForwardStream(stream.Context(), clientConn) + if err != nil { + return errors.WithStack(err) + } + lastForwardedHost = forwardedHost + } + + tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh) + if err != nil { + return errors.WithStack(err) + } + + // The error types defined for tsopb and pdpb are different, so we need to convert them. + var pdpbErr *pdpb.Error + tsopbErr := tsopbResp.GetHeader().GetError() + if tsopbErr != nil { + if tsopbErr.Type == tsopb.ErrorType_OK { + pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_OK, + Message: tsopbErr.GetMessage(), + } + } else { + // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. 
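+				// pdpb has no matching error type for this tsopb error yet, so
+				// it is reported as UNKNOWN with the original message preserved.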
+ pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_UNKNOWN, + Message: tsopbErr.GetMessage(), + } + } + } + + response := &pdpb.TsoResponse{ + Header: &pdpb.ResponseHeader{ + ClusterId: tsopbResp.GetHeader().GetClusterId(), + Error: pdpbErr, + }, + Count: tsopbResp.GetCount(), + Timestamp: tsopbResp.GetTimestamp(), + } + if err := server.Send(response); err != nil { + return errors.WithStack(err) + } + } +} + +func (s *GrpcServer) forwardTSORequestWithDeadLine( + forwardCtx context.Context, + cancelForward context.CancelFunc, + forwardStream tsopb.TSO_TsoClient, + request *pdpb.TsoRequest, + tsDeadlineCh chan<- *tsoutil.TSDeadline) (*tsopb.TsoResponse, error) { + done := make(chan struct{}) + dl := tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, cancelForward) + select { + case tsDeadlineCh <- dl: + case <-forwardCtx.Done(): + return nil, forwardCtx.Err() + } + + start := time.Now() + resp, err := s.forwardTSORequest(forwardCtx, request, forwardStream) + close(done) + if err != nil { + if strings.Contains(err.Error(), errs.NotLeaderErr) { + s.tsoPrimaryWatcher.ForceLoad() + } + return nil, err + } + tsoProxyBatchSize.Observe(float64(request.GetCount())) + tsoProxyHandleDuration.Observe(time.Since(start).Seconds()) + return resp, nil +} + +func (s *GrpcServer) createTSOForwardStream(ctx context.Context, client *grpc.ClientConn) (tsopb.TSO_TsoClient, context.Context, context.CancelFunc, error) { + done := make(chan struct{}) + forwardCtx, cancelForward := context.WithCancel(ctx) + go grpcutil.CheckStream(forwardCtx, cancelForward, done) + forwardStream, err := tsopb.NewTSOClient(client).Tso(forwardCtx) + done <- struct{}{} + return forwardStream, forwardCtx, cancelForward, err +} + +func (s *GrpcServer) createRegionHeartbeatForwardStream(client *grpc.ClientConn) (pdpb.PD_RegionHeartbeatClient, context.CancelFunc, error) { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go grpcutil.CheckStream(ctx, cancel, done) + forwardStream, err := pdpb.NewPDClient(client).RegionHeartbeat(ctx) + done <- struct{}{} + return forwardStream, cancel, err +} + +func (s *GrpcServer) createRegionHeartbeatSchedulingStream(ctx context.Context, client *grpc.ClientConn) (schedulingpb.Scheduling_RegionHeartbeatClient, context.Context, context.CancelFunc, error) { + done := make(chan struct{}) + forwardCtx, cancelForward := context.WithCancel(ctx) + go grpcutil.CheckStream(forwardCtx, cancelForward, done) + forwardStream, err := schedulingpb.NewSchedulingClient(client).RegionHeartbeat(forwardCtx) + done <- struct{}{} + return forwardStream, forwardCtx, cancelForward, err +} + +func forwardRegionHeartbeatToScheduling(forwardStream schedulingpb.Scheduling_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { + defer logutil.LogPanic() + defer close(errCh) + for { + resp, err := forwardStream.Recv() + if err == io.EOF { + errCh <- errors.WithStack(err) + return + } + if err != nil { + errCh <- errors.WithStack(err) + return + } + // The error types defined for schedulingpb and pdpb are different, so we need to convert them. + var pdpbErr *pdpb.Error + schedulingpbErr := resp.GetHeader().GetError() + if schedulingpbErr != nil { + if schedulingpbErr.Type == schedulingpb.ErrorType_OK { + pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_OK, + Message: schedulingpbErr.GetMessage(), + } + } else { + // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. 
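+				// Same conversion as the TSO path: a non-OK schedulingpb error
+				// is reported as UNKNOWN with the original message preserved.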
+ pdpbErr = &pdpb.Error{ + Type: pdpb.ErrorType_UNKNOWN, + Message: schedulingpbErr.GetMessage(), + } + } + } + response := &pdpb.RegionHeartbeatResponse{ + Header: &pdpb.ResponseHeader{ + ClusterId: resp.GetHeader().GetClusterId(), + Error: pdpbErr, + }, + ChangePeer: resp.GetChangePeer(), + TransferLeader: resp.GetTransferLeader(), + RegionId: resp.GetRegionId(), + RegionEpoch: resp.GetRegionEpoch(), + TargetPeer: resp.GetTargetPeer(), + Merge: resp.GetMerge(), + SplitRegion: resp.GetSplitRegion(), + ChangePeerV2: resp.GetChangePeerV2(), + SwitchWitnesses: resp.GetSwitchWitnesses(), + } + + if err := server.Send(response); err != nil { + errCh <- errors.WithStack(err) + return + } + } +} + +func forwardRegionHeartbeatClientToServer(forwardStream pdpb.PD_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { + defer logutil.LogPanic() + defer close(errCh) + for { + resp, err := forwardStream.Recv() + if err != nil { + errCh <- errors.WithStack(err) + return + } + if err := server.Send(resp); err != nil { + errCh <- errors.WithStack(err) + return + } + } +} + +func forwardReportBucketClientToServer(forwardStream pdpb.PD_ReportBucketsClient, server *bucketHeartbeatServer, errCh chan error) { + defer logutil.LogPanic() + defer close(errCh) + for { + resp, err := forwardStream.CloseAndRecv() + if err != nil { + errCh <- errors.WithStack(err) + return + } + if err := server.Send(resp); err != nil { + errCh <- errors.WithStack(err) + return + } + } +} + +func (s *GrpcServer) createReportBucketsForwardStream(client *grpc.ClientConn) (pdpb.PD_ReportBucketsClient, context.CancelFunc, error) { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go grpcutil.CheckStream(ctx, cancel, done) + forwardStream, err := pdpb.NewPDClient(client).ReportBuckets(ctx) + done <- struct{}{} + return forwardStream, cancel, err +} + +func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string) (*grpc.ClientConn, error) { + client, ok := s.clientConns.Load(forwardedHost) + if ok { + // Mostly, the connection is already established, and return it directly. + return client.(*grpc.ClientConn), nil + } + + tlsConfig, err := s.GetTLSConfig().ToTLSConfig() + if err != nil { + return nil, err + } + ctxTimeout, cancel := context.WithTimeout(ctx, defaultGRPCDialTimeout) + defer cancel() + newConn, err := grpcutil.GetClientConn(ctxTimeout, forwardedHost, tlsConfig) + if err != nil { + return nil, err + } + conn, loaded := s.clientConns.LoadOrStore(forwardedHost, newConn) + if !loaded { + // Successfully stored the connection we created. + return newConn, nil + } + // Loaded a connection created/stored by another goroutine, so close the one we created + // and return the one we loaded. 
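+	// LoadOrStore guarantees at most one cached connection per forwardedHost,
+	// so later callers all share the connection that won the race.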
+ newConn.Close() + return conn.(*grpc.ClientConn), nil +} + +func (s *GrpcServer) getForwardedHost(ctx, streamCtx context.Context, serviceName ...string) (forwardedHost string, err error) { + if s.IsAPIServiceMode() { + var ok bool + if len(serviceName) == 0 { + return "", ErrNotFoundService + } + forwardedHost, ok = s.GetServicePrimaryAddr(ctx, serviceName[0]) + if !ok || len(forwardedHost) == 0 { + switch serviceName[0] { + case utils.TSOServiceName: + return "", ErrNotFoundTSOAddr + case utils.SchedulingServiceName: + return "", ErrNotFoundSchedulingAddr + } + } + } else if fh := grpcutil.GetForwardedHost(streamCtx); !s.isLocalRequest(fh) { + forwardedHost = fh + } + return forwardedHost, nil +} + +func (s *GrpcServer) isLocalRequest(forwardedHost string) bool { + failpoint.Inject("useForwardRequest", func() { + failpoint.Return(false) + }) + if forwardedHost == "" { + return true + } + memberAddrs := s.GetMember().Member().GetClientUrls() + for _, addr := range memberAddrs { + if addr == forwardedHost { + return true + } + } + return false +} + +func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) { + if !s.IsAPIServiceMode() { + return s.tsoAllocatorManager.HandleRequest(ctx, tso.GlobalDCLocation, 1) + } + request := &tsopb.TsoRequest{ + Header: &tsopb.RequestHeader{ + ClusterId: s.clusterID, + KeyspaceId: utils.DefaultKeyspaceID, + KeyspaceGroupId: utils.DefaultKeyspaceGroupID, + }, + Count: 1, + } + var ( + forwardedHost string + forwardStream tsopb.TSO_TsoClient + ts *tsopb.TsoResponse + err error + ok bool + ) + handleStreamError := func(err error) (needRetry bool) { + if strings.Contains(err.Error(), errs.NotLeaderErr) { + s.tsoPrimaryWatcher.ForceLoad() + log.Warn("force to load tso primary address due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return true + } + if grpcutil.NeedRebuildConnection(err) { + s.tsoClientPool.Lock() + delete(s.tsoClientPool.clients, forwardedHost) + s.tsoClientPool.Unlock() + log.Warn("client connection removed due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return true + } + return false + } + for i := 0; i < maxRetryTimesRequestTSOServer; i++ { + if i > 0 { + time.Sleep(retryIntervalRequestTSOServer) + } + forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) + if !ok || forwardedHost == "" { + return pdpb.Timestamp{}, ErrNotFoundTSOAddr + } + forwardStream, err = s.getTSOForwardStream(forwardedHost) + if err != nil { + return pdpb.Timestamp{}, err + } + err = forwardStream.Send(request) + if err != nil { + if needRetry := handleStreamError(err); needRetry { + continue + } + log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err + } + ts, err = forwardStream.Recv() + if err != nil { + if needRetry := handleStreamError(err); needRetry { + continue + } + log.Error("receive response from tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err + } + return *ts.GetTimestamp(), nil + } + log.Error("get global tso from tso primary server failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost)) + return pdpb.Timestamp{}, err +} + +func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoClient, error) { + s.tsoClientPool.RLock() + forwardStream, ok := s.tsoClientPool.clients[forwardedHost] + s.tsoClientPool.RUnlock() + if ok { + // This is the common case to return here + return 
forwardStream, nil + } + + s.tsoClientPool.Lock() + defer s.tsoClientPool.Unlock() + + // Double check after entering the critical section + forwardStream, ok = s.tsoClientPool.clients[forwardedHost] + if ok { + return forwardStream, nil + } + + // Now let's create the client connection and the forward stream + client, err := s.getDelegateClient(s.ctx, forwardedHost) + if err != nil { + return nil, err + } + done := make(chan struct{}) + ctx, cancel := context.WithCancel(s.ctx) + go grpcutil.CheckStream(ctx, cancel, done) + forwardStream, err = tsopb.NewTSOClient(client).Tso(ctx) + done <- struct{}{} + if err != nil { + return nil, err + } + s.tsoClientPool.clients[forwardedHost] = forwardStream + return forwardStream, nil +} diff --git a/server/grpc_service.go b/server/grpc_service.go index 34741d4da5b..b0384a7d629 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -70,6 +70,7 @@ var ( ErrSendHeartbeatTimeout = status.Errorf(codes.DeadlineExceeded, "send heartbeat timeout") ErrNotFoundTSOAddr = status.Errorf(codes.NotFound, "not found tso address") ErrNotFoundSchedulingAddr = status.Errorf(codes.NotFound, "not found scheduling address") + ErrNotFoundService = status.Errorf(codes.NotFound, "not found service") ErrForwardTSOTimeout = status.Errorf(codes.DeadlineExceeded, "forward tso request timeout") ErrMaxCountTSOProxyRoutinesExceeded = status.Errorf(codes.ResourceExhausted, "max count of concurrent tso proxy routines exceeded") ErrTSOProxyRecvFromClientTimeout = status.Errorf(codes.DeadlineExceeded, "tso proxy timeout when receiving from client; stream closed by server") @@ -83,9 +84,120 @@ type GrpcServer struct { concurrentTSOProxyStreamings atomic.Int32 } +// tsoServer wraps PD_TsoServer to ensure when any error +// occurs on Send() or Recv(), both endpoints will be closed. 
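+// The closed flag is set on the first error or timeout so that subsequent
+// Send/Recv calls fail fast with io.EOF instead of reusing a broken stream.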
+type tsoServer struct { + stream pdpb.PD_TsoServer + closed int32 +} + +type pdpbTSORequest struct { + request *pdpb.TsoRequest + err error +} + +func (s *tsoServer) Send(m *pdpb.TsoResponse) error { + if atomic.LoadInt32(&s.closed) == 1 { + return io.EOF + } + done := make(chan error, 1) + go func() { + defer logutil.LogPanic() + failpoint.Inject("tsoProxyFailToSendToClient", func() { + done <- errors.New("injected error") + failpoint.Return() + }) + done <- s.stream.Send(m) + }() + timer := time.NewTimer(tsoutil.DefaultTSOProxyTimeout) + defer timer.Stop() + select { + case err := <-done: + if err != nil { + atomic.StoreInt32(&s.closed, 1) + } + return errors.WithStack(err) + case <-timer.C: + atomic.StoreInt32(&s.closed, 1) + return ErrForwardTSOTimeout + } +} + +func (s *tsoServer) Recv(timeout time.Duration) (*pdpb.TsoRequest, error) { + if atomic.LoadInt32(&s.closed) == 1 { + return nil, io.EOF + } + failpoint.Inject("tsoProxyRecvFromClientTimeout", func(val failpoint.Value) { + if customTimeoutInSeconds, ok := val.(int); ok { + timeout = time.Duration(customTimeoutInSeconds) * time.Second + } + }) + requestCh := make(chan *pdpbTSORequest, 1) + go func() { + defer logutil.LogPanic() + request, err := s.stream.Recv() + requestCh <- &pdpbTSORequest{request: request, err: err} + }() + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case req := <-requestCh: + if req.err != nil { + atomic.StoreInt32(&s.closed, 1) + return nil, errors.WithStack(req.err) + } + return req.request, nil + case <-timer.C: + atomic.StoreInt32(&s.closed, 1) + return nil, ErrTSOProxyRecvFromClientTimeout + } +} + +// heartbeatServer wraps PD_RegionHeartbeatServer to ensure when any error +// occurs on Send() or Recv(), both endpoints will be closed. +type heartbeatServer struct { + stream pdpb.PD_RegionHeartbeatServer + closed int32 +} + +func (s *heartbeatServer) Send(m core.RegionHeartbeatResponse) error { + if atomic.LoadInt32(&s.closed) == 1 { + return io.EOF + } + done := make(chan error, 1) + go func() { + defer logutil.LogPanic() + done <- s.stream.Send(m.(*pdpb.RegionHeartbeatResponse)) + }() + timer := time.NewTimer(heartbeatSendTimeout) + defer timer.Stop() + select { + case err := <-done: + if err != nil { + atomic.StoreInt32(&s.closed, 1) + } + return errors.WithStack(err) + case <-timer.C: + atomic.StoreInt32(&s.closed, 1) + return ErrSendHeartbeatTimeout + } +} + +func (s *heartbeatServer) Recv() (*pdpb.RegionHeartbeatRequest, error) { + if atomic.LoadInt32(&s.closed) == 1 { + return nil, io.EOF + } + req, err := s.stream.Recv() + if err != nil { + atomic.StoreInt32(&s.closed, 1) + return nil, errors.WithStack(err) + } + return req, nil +} + type schedulingClient struct { - client schedulingpb.SchedulingClient - lastPrimary string + client schedulingpb.SchedulingClient + primary string } func (s *schedulingClient) getClient() schedulingpb.SchedulingClient { @@ -99,7 +211,7 @@ func (s *schedulingClient) getPrimaryAddr() string { if s == nil { return "" } - return s.lastPrimary + return s.primary } type request interface { @@ -393,7 +505,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { return errors.WithStack(err) } - if forwardedHost, err := s.getForwardedHost(ctx, stream.Context()); err != nil { + if forwardedHost, err := s.getForwardedHost(ctx, stream.Context(), utils.TSOServiceName); err != nil { return err } else if len(forwardedHost) > 0 { clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) @@ -440,268 +552,6 @@ func (s *GrpcServer) Tso(stream 
pdpb.PD_TsoServer) error { } } -// forwardTSO forward the TSO requests to the TSO service. -func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { - var ( - server = &tsoServer{stream: stream} - forwardStream tsopb.TSO_TsoClient - forwardCtx context.Context - cancelForward context.CancelFunc - lastForwardedHost string - ) - defer func() { - s.concurrentTSOProxyStreamings.Add(-1) - if cancelForward != nil { - cancelForward() - } - }() - - maxConcurrentTSOProxyStreamings := int32(s.GetMaxConcurrentTSOProxyStreamings()) - if maxConcurrentTSOProxyStreamings >= 0 { - if newCount := s.concurrentTSOProxyStreamings.Add(1); newCount > maxConcurrentTSOProxyStreamings { - return errors.WithStack(ErrMaxCountTSOProxyRoutinesExceeded) - } - } - - tsDeadlineCh := make(chan *tsoutil.TSDeadline, 1) - go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh) - - for { - select { - case <-s.ctx.Done(): - return errors.WithStack(s.ctx.Err()) - case <-stream.Context().Done(): - return stream.Context().Err() - default: - } - - request, err := server.Recv(s.GetTSOProxyRecvFromClientTimeout()) - if err == io.EOF { - return nil - } - if err != nil { - return errors.WithStack(err) - } - if request.GetCount() == 0 { - err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") - return status.Errorf(codes.Unknown, err.Error()) - } - - forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.TSOServiceName) - if !ok || len(forwardedHost) == 0 { - return errors.WithStack(ErrNotFoundTSOAddr) - } - if forwardStream == nil || lastForwardedHost != forwardedHost { - if cancelForward != nil { - cancelForward() - } - - clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) - if err != nil { - return errors.WithStack(err) - } - forwardStream, forwardCtx, cancelForward, err = - s.createTSOForwardStream(stream.Context(), clientConn) - if err != nil { - return errors.WithStack(err) - } - lastForwardedHost = forwardedHost - } - - tsopbResp, err := s.forwardTSORequestWithDeadLine( - forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh) - if err != nil { - return errors.WithStack(err) - } - - // The error types defined for tsopb and pdpb are different, so we need to convert them. - var pdpbErr *pdpb.Error - tsopbErr := tsopbResp.GetHeader().GetError() - if tsopbErr != nil { - if tsopbErr.Type == tsopb.ErrorType_OK { - pdpbErr = &pdpb.Error{ - Type: pdpb.ErrorType_OK, - Message: tsopbErr.GetMessage(), - } - } else { - // TODO: specify FORWARD FAILURE error type instead of UNKNOWN. 
- pdpbErr = &pdpb.Error{ - Type: pdpb.ErrorType_UNKNOWN, - Message: tsopbErr.GetMessage(), - } - } - } - - response := &pdpb.TsoResponse{ - Header: &pdpb.ResponseHeader{ - ClusterId: tsopbResp.GetHeader().GetClusterId(), - Error: pdpbErr, - }, - Count: tsopbResp.GetCount(), - Timestamp: tsopbResp.GetTimestamp(), - } - if err := server.Send(response); err != nil { - return errors.WithStack(err) - } - } -} - -func (s *GrpcServer) forwardTSORequestWithDeadLine( - forwardCtx context.Context, - cancelForward context.CancelFunc, - forwardStream tsopb.TSO_TsoClient, - request *pdpb.TsoRequest, - tsDeadlineCh chan<- *tsoutil.TSDeadline, -) (*tsopb.TsoResponse, error) { - done := make(chan struct{}) - dl := tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, cancelForward) - select { - case tsDeadlineCh <- dl: - case <-forwardCtx.Done(): - return nil, forwardCtx.Err() - } - - start := time.Now() - resp, err := s.forwardTSORequest(forwardCtx, request, forwardStream) - close(done) - if err != nil { - if strings.Contains(err.Error(), errs.NotLeaderErr) { - s.tsoPrimaryWatcher.ForceLoad() - } - return nil, err - } - tsoProxyBatchSize.Observe(float64(request.GetCount())) - tsoProxyHandleDuration.Observe(time.Since(start).Seconds()) - return resp, nil -} - -func (s *GrpcServer) forwardTSORequest( - ctx context.Context, - request *pdpb.TsoRequest, - forwardStream tsopb.TSO_TsoClient, -) (*tsopb.TsoResponse, error) { - tsopbReq := &tsopb.TsoRequest{ - Header: &tsopb.RequestHeader{ - ClusterId: request.GetHeader().GetClusterId(), - SenderId: request.GetHeader().GetSenderId(), - KeyspaceId: utils.DefaultKeyspaceID, - KeyspaceGroupId: utils.DefaultKeyspaceGroupID, - }, - Count: request.GetCount(), - DcLocation: request.GetDcLocation(), - } - - failpoint.Inject("tsoProxySendToTSOTimeout", func() { - // block until watchDeadline routine cancels the context. - <-ctx.Done() - }) - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - if err := forwardStream.Send(tsopbReq); err != nil { - return nil, err - } - - failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() { - // block until watchDeadline routine cancels the context. - <-ctx.Done() - }) - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - return forwardStream.Recv() -} - -// tsoServer wraps PD_TsoServer to ensure when any error -// occurs on Send() or Recv(), both endpoints will be closed. 
-type tsoServer struct { - stream pdpb.PD_TsoServer - closed int32 -} - -type pdpbTSORequest struct { - request *pdpb.TsoRequest - err error -} - -func (s *tsoServer) Send(m *pdpb.TsoResponse) error { - if atomic.LoadInt32(&s.closed) == 1 { - return io.EOF - } - done := make(chan error, 1) - go func() { - defer logutil.LogPanic() - failpoint.Inject("tsoProxyFailToSendToClient", func() { - done <- errors.New("injected error") - failpoint.Return() - }) - done <- s.stream.Send(m) - }() - timer := time.NewTimer(tsoutil.DefaultTSOProxyTimeout) - defer timer.Stop() - select { - case err := <-done: - if err != nil { - atomic.StoreInt32(&s.closed, 1) - } - return errors.WithStack(err) - case <-timer.C: - atomic.StoreInt32(&s.closed, 1) - return ErrForwardTSOTimeout - } -} - -func (s *tsoServer) Recv(timeout time.Duration) (*pdpb.TsoRequest, error) { - if atomic.LoadInt32(&s.closed) == 1 { - return nil, io.EOF - } - failpoint.Inject("tsoProxyRecvFromClientTimeout", func(val failpoint.Value) { - if customTimeoutInSeconds, ok := val.(int); ok { - timeout = time.Duration(customTimeoutInSeconds) * time.Second - } - }) - requestCh := make(chan *pdpbTSORequest, 1) - go func() { - defer logutil.LogPanic() - request, err := s.stream.Recv() - requestCh <- &pdpbTSORequest{request: request, err: err} - }() - timer := time.NewTimer(timeout) - defer timer.Stop() - select { - case req := <-requestCh: - if req.err != nil { - atomic.StoreInt32(&s.closed, 1) - return nil, errors.WithStack(req.err) - } - return req.request, nil - case <-timer.C: - atomic.StoreInt32(&s.closed, 1) - return nil, ErrTSOProxyRecvFromClientTimeout - } -} - -func (s *GrpcServer) getForwardedHost(ctx, streamCtx context.Context) (forwardedHost string, err error) { - if s.IsAPIServiceMode() { - var ok bool - forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) - if !ok || len(forwardedHost) == 0 { - return "", ErrNotFoundTSOAddr - } - } else if fh := grpcutil.GetForwardedHost(streamCtx); !s.isLocalRequest(fh) { - forwardedHost = fh - } - return forwardedHost, nil -} - // Bootstrap implements gRPC PDServer. func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapRequest) (*pdpb.BootstrapResponse, error) { fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { @@ -1004,7 +854,8 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear storeHeartbeatHandleDuration.WithLabelValues(storeAddress, storeLabel).Observe(time.Since(start).Seconds()) if s.IsServiceIndependent(utils.SchedulingServiceName) { forwardCli, _ := s.updateSchedulingClient(ctx) - if forwardCli != nil { + cli := forwardCli.getClient() + if cli != nil { req := &schedulingpb.StoreHeartbeatRequest{ Header: &schedulingpb.RequestHeader{ ClusterId: request.GetHeader().GetClusterId(), @@ -1012,9 +863,10 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear }, Stats: request.GetStats(), } - if _, err := forwardCli.StoreHeartbeat(ctx, req); err != nil { + if _, err := cli.StoreHeartbeat(ctx, req); err != nil { + log.Debug("forward store heartbeat failed", zap.Error(err)) // reset to let it be updated in the next request - s.schedulingClient.Store(&schedulingClient{}) + s.schedulingClient.CompareAndSwap(forwardCli, &schedulingClient{}) } } } @@ -1031,28 +883,38 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear return resp, nil } -func (s *GrpcServer) updateSchedulingClient(ctx context.Context) (schedulingpb.SchedulingClient, error) { +// 1. 
forwardedHost is empty, return nil +// 2. forwardedHost is not empty and forwardedHost is equal to pre, return pre +// 3. the rest of cases, update forwardedHost and return new client +func (s *GrpcServer) updateSchedulingClient(ctx context.Context) (*schedulingClient, error) { forwardedHost, _ := s.GetServicePrimaryAddr(ctx, utils.SchedulingServiceName) + if forwardedHost == "" { + return nil, ErrNotFoundSchedulingAddr + } + pre := s.schedulingClient.Load() - // 1. forwardedHost is not empty and pre is empty, update the schedulingClient - // 2. forwardedHost is not empty and forwardedHost is not equal to pre, update the schedulingClient - // 3. forwardedHost is not empty and forwardedHost is equal to pre, return pre - // 4. forwardedHost is empty, return nil - if forwardedHost != "" && ((pre == nil) || (pre != nil && forwardedHost != pre.(*schedulingClient).getPrimaryAddr())) { - client, err := s.getDelegateClient(ctx, forwardedHost) - if err != nil { - log.Error("get delegate client failed", zap.Error(err)) - } - forwardCli := &schedulingClient{ - client: schedulingpb.NewSchedulingClient(client), - lastPrimary: forwardedHost, + if pre != nil && forwardedHost == pre.(*schedulingClient).getPrimaryAddr() { + return pre.(*schedulingClient), nil + } + + client, err := s.getDelegateClient(ctx, forwardedHost) + if err != nil { + log.Error("get delegate client failed", zap.Error(err)) + return nil, err + } + forwardCli := &schedulingClient{ + client: schedulingpb.NewSchedulingClient(client), + primary: forwardedHost, + } + swapped := s.schedulingClient.CompareAndSwap(pre, forwardCli) + if swapped { + oldForwardedHost := "" + if pre != nil { + oldForwardedHost = pre.(*schedulingClient).getPrimaryAddr() } - s.schedulingClient.Store(forwardCli) - return forwardCli.getClient(), nil - } else if forwardedHost != "" && (pre != nil && forwardedHost == pre.(*schedulingClient).getPrimaryAddr()) { - return pre.(*schedulingClient).getClient(), nil + log.Info("update scheduling client", zap.String("old-forwarded-host", oldForwardedHost), zap.String("new-forwarded-host", forwardedHost)) } - return nil, ErrNotFoundSchedulingAddr + return forwardCli, nil } // bucketHeartbeatServer wraps PD_ReportBucketsServer to ensure when any error @@ -1097,48 +959,6 @@ func (b *bucketHeartbeatServer) Recv() (*pdpb.ReportBucketsRequest, error) { return req, nil } -// heartbeatServer wraps PD_RegionHeartbeatServer to ensure when any error -// occurs on Send() or Recv(), both endpoints will be closed. 
-type heartbeatServer struct { - stream pdpb.PD_RegionHeartbeatServer - closed int32 -} - -func (s *heartbeatServer) Send(m core.RegionHeartbeatResponse) error { - if atomic.LoadInt32(&s.closed) == 1 { - return io.EOF - } - done := make(chan error, 1) - go func() { - defer logutil.LogPanic() - done <- s.stream.Send(m.(*pdpb.RegionHeartbeatResponse)) - }() - timer := time.NewTimer(heartbeatSendTimeout) - defer timer.Stop() - select { - case err := <-done: - if err != nil { - atomic.StoreInt32(&s.closed, 1) - } - return errors.WithStack(err) - case <-timer.C: - atomic.StoreInt32(&s.closed, 1) - return ErrSendHeartbeatTimeout - } -} - -func (s *heartbeatServer) Recv() (*pdpb.RegionHeartbeatRequest, error) { - if atomic.LoadInt32(&s.closed) == 1 { - return nil, io.EOF - } - req, err := s.stream.Recv() - if err != nil { - atomic.StoreInt32(&s.closed, 1) - return nil, errors.WithStack(err) - } - return req, nil -} - // ReportBuckets implements gRPC PDServer func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { var ( @@ -1236,16 +1056,16 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { // RegionHeartbeat implements gRPC PDServer. func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error { var ( - server = &heartbeatServer{stream: stream} - flowRoundOption = core.WithFlowRoundByDigit(s.persistOptions.GetPDServerConfig().FlowRoundByDigit) - forwardStream pdpb.PD_RegionHeartbeatClient - cancel context.CancelFunc - lastForwardedHost string - lastBind time.Time - errCh chan error - schedulingStream schedulingpb.Scheduling_RegionHeartbeatClient - cancel1 context.CancelFunc - lastPrimaryAddr string + server = &heartbeatServer{stream: stream} + flowRoundOption = core.WithFlowRoundByDigit(s.persistOptions.GetPDServerConfig().FlowRoundByDigit) + cancel context.CancelFunc + lastBind time.Time + errCh chan error + forwardStream pdpb.PD_RegionHeartbeatClient + lastForwardedHost string + forwardErrCh chan error + forwardSchedulingStream schedulingpb.Scheduling_RegionHeartbeatClient + lastForwardedSchedulingHost string ) defer func() { // cancel the forward stream @@ -1262,8 +1082,10 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error if err != nil { return errors.WithStack(err) } - forwardedHost := grpcutil.GetForwardedHost(stream.Context()) + failpoint.Inject("grpcClientClosed", func() { + forwardedHost = s.GetMember().Member().GetClientUrls()[0] + }) if !s.isLocalRequest(forwardedHost) { if forwardStream == nil || lastForwardedHost != forwardedHost { if cancel != nil { @@ -1274,7 +1096,7 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error return err } log.Info("create region heartbeat forward stream", zap.String("forwarded-host", forwardedHost)) - forwardStream, cancel, err = s.createHeartbeatForwardStream(client) + forwardStream, cancel, err = s.createRegionHeartbeatForwardStream(client) if err != nil { return err } @@ -1360,56 +1182,83 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error continue } + regionHeartbeatHandleDuration.WithLabelValues(storeAddress, storeLabel).Observe(time.Since(start).Seconds()) + regionHeartbeatCounter.WithLabelValues(storeAddress, storeLabel, "report", "ok").Inc() + if s.IsServiceIndependent(utils.SchedulingServiceName) { - ctx := stream.Context() - primaryAddr, _ := s.GetServicePrimaryAddr(ctx, utils.SchedulingServiceName) - if schedulingStream == nil || lastPrimaryAddr != primaryAddr { - if cancel1 != 
nil { - cancel1() + if forwardErrCh != nil { + select { + case err, ok := <-forwardErrCh: + if ok { + if cancel != nil { + cancel() + } + forwardSchedulingStream = nil + log.Error("meet error and need to re-establish the stream", zap.Error(err)) + } + default: } - client, err := s.getDelegateClient(ctx, primaryAddr) + } + forwardedSchedulingHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.SchedulingServiceName) + if !ok || len(forwardedSchedulingHost) == 0 { + log.Debug("failed to find scheduling service primary address") + if cancel != nil { + cancel() + } + continue + } + if forwardSchedulingStream == nil || lastForwardedSchedulingHost != forwardedSchedulingHost { + if cancel != nil { + cancel() + } + client, err := s.getDelegateClient(s.ctx, forwardedSchedulingHost) if err != nil { - log.Error("get delegate client failed", zap.Error(err)) + log.Error("failed to get client", zap.Error(err)) + continue } - - log.Info("create region heartbeat forward stream", zap.String("forwarded-host", primaryAddr)) - schedulingStream, cancel1, err = s.createSchedulingStream(client) + log.Info("create scheduling forwarding stream", zap.String("forwarded-host", forwardedSchedulingHost)) + forwardSchedulingStream, _, cancel, err = s.createRegionHeartbeatSchedulingStream(stream.Context(), client) if err != nil { - log.Error("create region heartbeat forward stream failed", zap.Error(err)) - } else { - lastPrimaryAddr = primaryAddr - errCh = make(chan error, 1) - go forwardSchedulingToServer(schedulingStream, server, errCh) + log.Error("failed to create stream", zap.Error(err)) + continue } + lastForwardedSchedulingHost = forwardedSchedulingHost + forwardErrCh = make(chan error, 1) + go forwardRegionHeartbeatToScheduling(forwardSchedulingStream, server, forwardErrCh) } - if schedulingStream != nil { - req := &schedulingpb.RegionHeartbeatRequest{ - Header: &schedulingpb.RequestHeader{ - ClusterId: request.GetHeader().GetClusterId(), - SenderId: request.GetHeader().GetSenderId(), - }, - Region: request.GetRegion(), - Leader: request.GetLeader(), - DownPeers: request.GetDownPeers(), - PendingPeers: request.GetPendingPeers(), - BytesWritten: request.GetBytesWritten(), - BytesRead: request.GetBytesRead(), - KeysWritten: request.GetKeysWritten(), - KeysRead: request.GetKeysRead(), - ApproximateSize: request.GetApproximateSize(), - ApproximateKeys: request.GetApproximateKeys(), - Interval: request.GetInterval(), - Term: request.GetTerm(), - QueryStats: request.GetQueryStats(), - } - if err := schedulingStream.Send(req); err != nil { - log.Error("forward region heartbeat failed", zap.Error(err)) + schedulingpbReq := &schedulingpb.RegionHeartbeatRequest{ + Header: &schedulingpb.RequestHeader{ + ClusterId: request.GetHeader().GetClusterId(), + SenderId: request.GetHeader().GetSenderId(), + }, + Region: request.GetRegion(), + Leader: request.GetLeader(), + DownPeers: request.GetDownPeers(), + PendingPeers: request.GetPendingPeers(), + BytesWritten: request.GetBytesWritten(), + BytesRead: request.GetBytesRead(), + KeysWritten: request.GetKeysWritten(), + KeysRead: request.GetKeysRead(), + ApproximateSize: request.GetApproximateSize(), + ApproximateKeys: request.GetApproximateKeys(), + Interval: request.GetInterval(), + Term: request.GetTerm(), + QueryStats: request.GetQueryStats(), + } + if err := forwardSchedulingStream.Send(schedulingpbReq); err != nil { + forwardSchedulingStream = nil + log.Error("failed to send request to scheduling service", zap.Error(err)) + } + + select { + case err, ok := 
<-forwardErrCh: + if ok { + forwardSchedulingStream = nil + log.Error("failed to send response", zap.Error(err)) } + default: } } - - regionHeartbeatHandleDuration.WithLabelValues(storeAddress, storeLabel).Observe(time.Since(start).Seconds()) - regionHeartbeatCounter.WithLabelValues(storeAddress, storeLabel, "report", "ok").Inc() } } @@ -1639,7 +1488,8 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } - if forwardCli != nil { + cli := forwardCli.getClient() + if cli != nil { req := &schedulingpb.AskBatchSplitRequest{ Header: &schedulingpb.RequestHeader{ ClusterId: request.GetHeader().GetClusterId(), @@ -1648,10 +1498,10 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp Region: request.GetRegion(), SplitCount: request.GetSplitCount(), } - resp, err := s.schedulingClient.Load().(*schedulingClient).getClient().AskBatchSplit(ctx, req) + resp, err := cli.AskBatchSplit(ctx, req) if err != nil { // reset to let it be updated in the next request - s.schedulingClient.Store(&schedulingClient{}) + s.schedulingClient.CompareAndSwap(forwardCli, &schedulingClient{}) return s.convertAskSplitResponse(resp), err } return s.convertAskSplitResponse(resp), nil @@ -1812,7 +1662,8 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } - if forwardCli != nil { + cli := forwardCli.getClient() + if cli != nil { var regionsID []uint64 // nolint if request.GetRegionId() != 0 { @@ -1836,10 +1687,10 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg RetryLimit: request.GetRetryLimit(), SkipStoreLimit: request.GetSkipStoreLimit(), } - resp, err := forwardCli.ScatterRegions(ctx, req) + resp, err := cli.ScatterRegions(ctx, req) if err != nil { // reset to let it be updated in the next request - s.schedulingClient.Store(&schedulingClient{}) + s.schedulingClient.CompareAndSwap(forwardCli, &schedulingClient{}) return s.convertScatterResponse(resp), err } return s.convertScatterResponse(resp), nil @@ -2035,7 +1886,8 @@ func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorR Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } - if forwardCli != nil { + cli := forwardCli.getClient() + if cli != nil { req := &schedulingpb.GetOperatorRequest{ Header: &schedulingpb.RequestHeader{ ClusterId: request.GetHeader().GetClusterId(), @@ -2043,10 +1895,10 @@ func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorR }, RegionId: request.GetRegionId(), } - resp, err := forwardCli.GetOperator(ctx, req) + resp, err := cli.GetOperator(ctx, req) if err != nil { // reset to let it be updated in the next request - s.schedulingClient.Store(&schedulingClient{}) + s.schedulingClient.CompareAndSwap(forwardCli, &schedulingClient{}) return s.convertOperatorResponse(resp), err } return s.convertOperatorResponse(resp), nil @@ -2307,7 +2159,8 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } - if forwardCli != nil { + cli := forwardCli.getClient() + if cli != nil { req := &schedulingpb.SplitRegionsRequest{ Header: &schedulingpb.RequestHeader{ ClusterId: request.GetHeader().GetClusterId(), @@ -2316,10 +2169,10 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion SplitKeys: 
request.GetSplitKeys(), RetryLimit: request.GetRetryLimit(), } - resp, err := forwardCli.SplitRegions(ctx, req) + resp, err := cli.SplitRegions(ctx, req) if err != nil { // reset to let it be updated in the next request - s.schedulingClient.Store(&schedulingClient{}) + s.schedulingClient.CompareAndSwap(forwardCli, &schedulingClient{}) return s.convertSplitResponse(resp), err } return s.convertSplitResponse(resp), nil @@ -2451,258 +2304,6 @@ func (s *GrpcServer) validateInternalRequest(header *pdpb.RequestHeader, onlyAll return nil } -func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string) (*grpc.ClientConn, error) { - client, ok := s.clientConns.Load(forwardedHost) - if ok { - // Mostly, the connection is already established, and return it directly. - return client.(*grpc.ClientConn), nil - } - - tlsConfig, err := s.GetTLSConfig().ToTLSConfig() - if err != nil { - return nil, err - } - ctxTimeout, cancel := context.WithTimeout(ctx, defaultGRPCDialTimeout) - defer cancel() - newConn, err := grpcutil.GetClientConn(ctxTimeout, forwardedHost, tlsConfig) - if err != nil { - return nil, err - } - conn, loaded := s.clientConns.LoadOrStore(forwardedHost, newConn) - if !loaded { - // Successfully stored the connection we created. - return newConn, nil - } - // Loaded a connection created/stored by another goroutine, so close the one we created - // and return the one we loaded. - newConn.Close() - return conn.(*grpc.ClientConn), nil -} - -func (s *GrpcServer) isLocalRequest(forwardedHost string) bool { - failpoint.Inject("useForwardRequest", func() { - failpoint.Return(false) - }) - if forwardedHost == "" { - return true - } - memberAddrs := s.GetMember().Member().GetClientUrls() - for _, addr := range memberAddrs { - if addr == forwardedHost { - return true - } - } - return false -} - -func (s *GrpcServer) createHeartbeatForwardStream(client *grpc.ClientConn) (pdpb.PD_RegionHeartbeatClient, context.CancelFunc, error) { - done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go grpcutil.CheckStream(ctx, cancel, done) - forwardStream, err := pdpb.NewPDClient(client).RegionHeartbeat(ctx) - done <- struct{}{} - return forwardStream, cancel, err -} - -func forwardRegionHeartbeatClientToServer(forwardStream pdpb.PD_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { - defer logutil.LogPanic() - defer close(errCh) - for { - resp, err := forwardStream.Recv() - if err != nil { - errCh <- errors.WithStack(err) - return - } - if err := server.Send(resp); err != nil { - errCh <- errors.WithStack(err) - return - } - } -} - -func (s *GrpcServer) createSchedulingStream(client *grpc.ClientConn) (schedulingpb.Scheduling_RegionHeartbeatClient, context.CancelFunc, error) { - if client == nil { - return nil, nil, errors.New("connection is not set") - } - done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go grpcutil.CheckStream(ctx, cancel, done) - forwardStream, err := schedulingpb.NewSchedulingClient(client).RegionHeartbeat(ctx) - done <- struct{}{} - return forwardStream, cancel, err -} - -func forwardSchedulingToServer(forwardStream schedulingpb.Scheduling_RegionHeartbeatClient, server *heartbeatServer, errCh chan error) { - defer logutil.LogPanic() - defer close(errCh) - for { - resp, err := forwardStream.Recv() - if err != nil { - errCh <- errors.WithStack(err) - return - } - response := &pdpb.RegionHeartbeatResponse{ - Header: &pdpb.ResponseHeader{ - ClusterId: resp.GetHeader().GetClusterId(), - // ignore error here - }, - 
ChangePeer: resp.GetChangePeer(), - TransferLeader: resp.GetTransferLeader(), - RegionId: resp.GetRegionId(), - RegionEpoch: resp.GetRegionEpoch(), - TargetPeer: resp.GetTargetPeer(), - Merge: resp.GetMerge(), - SplitRegion: resp.GetSplitRegion(), - ChangePeerV2: resp.GetChangePeerV2(), - SwitchWitnesses: resp.GetSwitchWitnesses(), - } - - if err := server.Send(response); err != nil { - errCh <- errors.WithStack(err) - return - } - } -} - -func (s *GrpcServer) createTSOForwardStream( - ctx context.Context, client *grpc.ClientConn, -) (tsopb.TSO_TsoClient, context.Context, context.CancelFunc, error) { - done := make(chan struct{}) - forwardCtx, cancelForward := context.WithCancel(ctx) - go grpcutil.CheckStream(forwardCtx, cancelForward, done) - forwardStream, err := tsopb.NewTSOClient(client).Tso(forwardCtx) - done <- struct{}{} - return forwardStream, forwardCtx, cancelForward, err -} - -func (s *GrpcServer) createReportBucketsForwardStream(client *grpc.ClientConn) (pdpb.PD_ReportBucketsClient, context.CancelFunc, error) { - done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go grpcutil.CheckStream(ctx, cancel, done) - forwardStream, err := pdpb.NewPDClient(client).ReportBuckets(ctx) - done <- struct{}{} - return forwardStream, cancel, err -} - -func forwardReportBucketClientToServer(forwardStream pdpb.PD_ReportBucketsClient, server *bucketHeartbeatServer, errCh chan error) { - defer logutil.LogPanic() - defer close(errCh) - for { - resp, err := forwardStream.CloseAndRecv() - if err != nil { - errCh <- errors.WithStack(err) - return - } - if err := server.Send(resp); err != nil { - errCh <- errors.WithStack(err) - return - } - } -} - -func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) { - if !s.IsAPIServiceMode() { - return s.tsoAllocatorManager.HandleRequest(ctx, tso.GlobalDCLocation, 1) - } - request := &tsopb.TsoRequest{ - Header: &tsopb.RequestHeader{ - ClusterId: s.clusterID, - KeyspaceId: utils.DefaultKeyspaceID, - KeyspaceGroupId: utils.DefaultKeyspaceGroupID, - }, - Count: 1, - } - var ( - forwardedHost string - forwardStream tsopb.TSO_TsoClient - ts *tsopb.TsoResponse - err error - ok bool - ) - handleStreamError := func(err error) (needRetry bool) { - if strings.Contains(err.Error(), errs.NotLeaderErr) { - s.tsoPrimaryWatcher.ForceLoad() - log.Warn("force to load tso primary address due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return true - } - if grpcutil.NeedRebuildConnection(err) { - s.tsoClientPool.Lock() - delete(s.tsoClientPool.clients, forwardedHost) - s.tsoClientPool.Unlock() - log.Warn("client connection removed due to error", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return true - } - return false - } - for i := 0; i < maxRetryTimesRequestTSOServer; i++ { - if i > 0 { - time.Sleep(retryIntervalRequestTSOServer) - } - forwardedHost, ok = s.GetServicePrimaryAddr(ctx, utils.TSOServiceName) - if !ok || forwardedHost == "" { - return pdpb.Timestamp{}, ErrNotFoundTSOAddr - } - forwardStream, err = s.getTSOForwardStream(forwardedHost) - if err != nil { - return pdpb.Timestamp{}, err - } - err = forwardStream.Send(request) - if err != nil { - if needRetry := handleStreamError(err); needRetry { - continue - } - log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return pdpb.Timestamp{}, err - } - ts, err = forwardStream.Recv() - if err != nil { - if needRetry := handleStreamError(err); needRetry { - continue - } - 
log.Error("receive response from tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return pdpb.Timestamp{}, err - } - return *ts.GetTimestamp(), nil - } - log.Error("get global tso from tso primary server failed after retry", zap.Error(err), zap.String("tso-addr", forwardedHost)) - return pdpb.Timestamp{}, err -} - -func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoClient, error) { - s.tsoClientPool.RLock() - forwardStream, ok := s.tsoClientPool.clients[forwardedHost] - s.tsoClientPool.RUnlock() - if ok { - // This is the common case to return here - return forwardStream, nil - } - - s.tsoClientPool.Lock() - defer s.tsoClientPool.Unlock() - - // Double check after entering the critical section - forwardStream, ok = s.tsoClientPool.clients[forwardedHost] - if ok { - return forwardStream, nil - } - - // Now let's create the client connection and the forward stream - client, err := s.getDelegateClient(s.ctx, forwardedHost) - if err != nil { - return nil, err - } - done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go grpcutil.CheckStream(ctx, cancel, done) - forwardStream, err = tsopb.NewTSOClient(client).Tso(ctx) - done <- struct{}{} - if err != nil { - return nil, err - } - s.tsoClientPool.clients[forwardedHost] = forwardStream - return forwardStream, nil -} - // for CDC compatibility, we need to initialize config path to `globalConfigPath` const globalConfigPath = "/global/config/"