From 46fe516fb2acf628aba79467506854c92a4595cd Mon Sep 17 00:00:00 2001 From: JmPotato Date: Thu, 19 Dec 2024 18:18:49 +0800 Subject: [PATCH 1/4] Implement the query region gRPC client Signed-off-by: JmPotato --- client/client.go | 25 +- client/clients/router/client.go | 433 +++++++++++++++++++++++ client/clients/router/request.go | 115 ++++++ client/errs/errno.go | 1 + client/go.mod | 2 +- client/go.sum | 4 +- client/inner_client.go | 3 + client/pkg/utils/timerutil/util.go | 32 ++ tests/integrations/client/client_test.go | 1 - 9 files changed, 591 insertions(+), 25 deletions(-) create mode 100644 client/clients/router/request.go create mode 100644 client/pkg/utils/timerutil/util.go diff --git a/client/client.go b/client/client.go index fa0a1473ba7..f1c0a49ff8a 100644 --- a/client/client.go +++ b/client/client.go @@ -570,23 +570,6 @@ func (c *client) GetMinTS(ctx context.Context) (physical int64, logical int64, e return minTS.Physical, minTS.Logical, nil } -func handleRegionResponse(res *pdpb.GetRegionResponse) *router.Region { - if res.Region == nil { - return nil - } - - r := &router.Region{ - Meta: res.Region, - Leader: res.Leader, - PendingPeers: res.PendingPeers, - Buckets: res.Buckets, - } - for _, s := range res.DownPeers { - r.DownPeers = append(r.DownPeers, s.Peer) - } - return r -} - // GetRegionFromMember implements the RPCClient interface. func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs []string, _ ...opt.GetRegionOption) (*router.Region, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { @@ -623,7 +606,7 @@ func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs errorMsg := fmt.Sprintf("[pd] can't get region info from member URLs: %+v", memberURLs) return nil, errors.WithStack(errors.New(errorMsg)) } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // GetRegion implements the RPCClient interface. @@ -663,7 +646,7 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio if err = c.respForErr(metrics.CmdFailedDurationGetRegion, start, err, resp.GetHeader()); err != nil { return nil, err } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // GetPrevRegion implements the RPCClient interface. @@ -703,7 +686,7 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR if err = c.respForErr(metrics.CmdFailedDurationGetPrevRegion, start, err, resp.GetHeader()); err != nil { return nil, err } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // GetRegionByID implements the RPCClient interface. @@ -744,7 +727,7 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt if err = c.respForErr(metrics.CmdFailedDurationGetRegionByID, start, err, resp.GetHeader()); err != nil { return nil, err } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // ScanRegions implements the RPCClient interface. diff --git a/client/clients/router/client.go b/client/clients/router/client.go index 48cebfa950e..8d0cfd64d2d 100644 --- a/client/clients/router/client.go +++ b/client/clients/router/client.go @@ -18,12 +18,31 @@ import ( "context" "encoding/hex" "net/url" + "runtime/trace" + "sync" + "sync/atomic" + "time" + + "github.com/opentracing/opentracing-go" + "go.uber.org/zap" + "google.golang.org/grpc" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/opt" + "github.com/tikv/pd/client/pkg/batch" + cctx "github.com/tikv/pd/client/pkg/connectionctx" + "github.com/tikv/pd/client/pkg/retry" + "github.com/tikv/pd/client/pkg/utils/timerutil" + sd "github.com/tikv/pd/client/servicediscovery" ) +// defaultMaxRouterRequestBatchSize is the default max size of the router request batch. +const defaultMaxRouterRequestBatchSize = 10000 + // Region contains information of a region's meta and its peers. type Region struct { Meta *metapb.Region @@ -33,6 +52,33 @@ type Region struct { Buckets *metapb.Buckets } +type regionResponse interface { + GetRegion() *metapb.Region + GetLeader() *metapb.Peer + GetDownPeers() []*pdpb.PeerStats + GetPendingPeers() []*metapb.Peer + GetBuckets() *metapb.Buckets +} + +// ConvertToRegion converts the region response to the region. +func ConvertToRegion(res regionResponse) *Region { + region := res.GetRegion() + if region == nil { + return nil + } + + r := &Region{ + Meta: region, + Leader: res.GetLeader(), + PendingPeers: res.GetPendingPeers(), + Buckets: res.GetBuckets(), + } + for _, s := range res.GetDownPeers() { + r.DownPeers = append(r.DownPeers, s.Peer) + } + return r +} + // KeyRange defines a range of keys in bytes. type KeyRange struct { StartKey []byte @@ -92,3 +138,390 @@ type Client interface { // The returned regions are flattened, even there are key ranges located in the same region, only one region will be returned. BatchScanRegions(ctx context.Context, keyRanges []KeyRange, limit int, opts ...opt.GetRegionOption) ([]*Region, error) } + +// Cli is the implementation of the router client. +type Cli struct { + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + option *opt.Option + + svcDiscovery sd.ServiceDiscovery + // leaderURL is the URL of the router leader. + leaderURL atomic.Value + // conCtxMgr is used to store the context of the router stream connection(s). + conCtxMgr *cctx.Manager[pdpb.PD_QueryRegionClient] + // updateConnectionCh is used to trigger the connection update actively. + updateConnectionCh chan struct{} + // bo is the backoffer for the router client. + bo *retry.Backoffer + + reqPool *sync.Pool + requestCh chan *Request + batchController *batch.Controller[*Request] +} + +// NewClient returns a new router client. +func NewClient( + ctx context.Context, + svcDiscovery sd.ServiceDiscovery, + option *opt.Option, +) *Cli { + ctx, cancel := context.WithCancel(ctx) + c := &Cli{ + ctx: ctx, + cancel: cancel, + svcDiscovery: svcDiscovery, + option: option, + conCtxMgr: cctx.NewManager[pdpb.PD_QueryRegionClient](), + updateConnectionCh: make(chan struct{}, 1), + bo: retry.InitialBackoffer( + sd.UpdateMemberBackOffBaseTime, sd.UpdateMemberTimeout, sd.UpdateMemberBackOffBaseTime), + reqPool: &sync.Pool{ + New: func() any { + return &Request{ + done: make(chan error, 1), + } + }, + }, + requestCh: make(chan *Request, defaultMaxRouterRequestBatchSize*2), + batchController: batch.NewController(defaultMaxRouterRequestBatchSize, requestFinisher(nil), nil), + } + c.leaderURL.Store(svcDiscovery.GetServingURL()) + c.svcDiscovery.ExecAndAddLeaderSwitchedCallback(c.updateLeaderURL) + + c.wg.Add(2) + go c.connectionDaemon() + go c.dispatcher() + + return c +} + +func (c *Cli) newRequest(ctx context.Context) *Request { + req := c.reqPool.Get().(*Request) + req.requestCtx = ctx + req.clientCtx = c.ctx + req.pool = c.reqPool + + return req +} + +func requestFinisher(resp *pdpb.QueryRegionResponse) batch.FinisherFunc[*Request] { + var keyIdx, prevKeyIdx int + return func(_ int, req *Request, err error) { + requestCtx := req.requestCtx + defer trace.StartRegion(requestCtx, "pdclient.regionReqDone").End() + + if err != nil { + req.tryDone(err) + return + } + + var id uint64 + if req.key != nil { + id = resp.KeyIdMap[keyIdx] + keyIdx++ + } else if req.prevKey != nil { + id = resp.PrevKeyIdMap[prevKeyIdx] + prevKeyIdx++ + } else if req.id != 0 { + id = req.id + } + region, ok := resp.RegionsById[id] + if !ok { + err = errs.ErrClientRegionNotFound.FastGenByArgs(id) + } else { + req.region = ConvertToRegion(region) + } + req.tryDone(err) + } +} + +func (c *Cli) cancelCollectedRequests(err error) { + c.batchController.FinishCollectedRequests(requestFinisher(nil), err) +} + +func (c *Cli) doneCollectedRequests(resp *pdpb.QueryRegionResponse) { + c.batchController.FinishCollectedRequests(requestFinisher(resp), nil) +} + +// Close closes the router client. +func (c *Cli) Close() { + if c == nil { + return + } + log.Info("[router] closing router client") + + c.cancel() + c.wg.Wait() + + log.Info("[router] router client is closed") +} + +func (c *Cli) getLeaderURL() string { + url := c.leaderURL.Load() + if url == nil { + return "" + } + return url.(string) +} + +func (c *Cli) updateLeaderURL(url string) error { + oldURL := c.getLeaderURL() + if !c.leaderURL.CompareAndSwap(oldURL, url) { + return nil + } + c.scheduleUpdateConnection() + + log.Info("[router] switch the router leader serving url", + zap.String("old-url", oldURL), zap.String("new-url", url)) + return nil +} + +// getLeaderClientConn returns the leader gRPC client connection. +func (c *Cli) getLeaderClientConn() (*grpc.ClientConn, string) { + url := c.getLeaderURL() + if len(url) == 0 { + c.svcDiscovery.ScheduleCheckMemberChanged() + return nil, "" + } + cc, ok := c.svcDiscovery.GetClientConns().Load(url) + if !ok { + return nil, url + } + return cc.(*grpc.ClientConn), url +} + +// scheduleUpdateConnection is used to schedule an update to the connection(s). +func (c *Cli) scheduleUpdateConnection() { + select { + case c.updateConnectionCh <- struct{}{}: + default: + } +} + +// connectionDaemon is used to update the router leader/primary/backup connection(s) in background. +// It aims to provide a seamless connection updating for the router client to keep providing the +// router service without interruption. +func (c *Cli) connectionDaemon() { + defer c.wg.Done() + updaterCtx, updaterCancel := context.WithCancel(c.ctx) + defer updaterCancel() + updateTicker := time.NewTicker(sd.MemberUpdateInterval) + defer updateTicker.Stop() + + log.Info("[router] connection daemon is started") + for { + c.updateConnection(updaterCtx) + select { + case <-updaterCtx.Done(): + log.Info("[router] connection daemon is exiting") + return + case <-updateTicker.C: + case <-c.updateConnectionCh: + } + } +} + +// updateConnection is used to get the leader client connection and update the connection context if it does not exist before. +func (c *Cli) updateConnection(ctx context.Context) { + cc, url := c.getLeaderClientConn() + if cc == nil || len(url) == 0 { + log.Warn("[router] got an invalid leader client connection", zap.String("url", url)) + return + } + if c.conCtxMgr.Exist(url) { + log.Debug("[router] the router leader remains unchanged", zap.String("url", url)) + return + } + stream, err := pdpb.NewPDClient(cc).QueryRegion(ctx) + if err != nil { + log.Error("[router] failed to create the router stream connection", errs.ZapError(err)) + } + c.conCtxMgr.Store(ctx, url, stream) + // TODO: support the forwarding mechanism for the router client. + // TODO: support sending the router requests to the follower nodes. +} + +func (c *Cli) dispatcher() { + defer c.wg.Done() + + var ( + stream pdpb.PD_QueryRegionClient + streamURL string + streamCtx context.Context + timeoutTimer *time.Timer + resetTimeoutTimer = func() { + if timeoutTimer == nil { + timeoutTimer = time.NewTimer(c.option.Timeout) + } else { + timerutil.SafeResetTimer(timeoutTimer, c.option.Timeout) + } + } + ctx, cancel = context.WithCancel(c.ctx) + ) + + log.Info("[router] dispatcher is started") + defer func() { + log.Info("[router] dispatcher is exiting") + cancel() + if timeoutTimer != nil { + timeoutTimer.Stop() + } + log.Info("[router] dispatcher exited") + }() +batchLoop: + for { + select { + case <-ctx.Done(): + return + default: + } + + // Step 1: Fetch the pending router requests in batch. + err := c.batchController.FetchPendingRequests(ctx, c.requestCh, nil, 0) + if err != nil { + if err == context.Canceled { + log.Info("[router] stop fetching the pending router requests due to context canceled") + } else { + log.Error("[router] failed to fetch the pending router requests", errs.ZapError(err)) + } + return + } + + // Step 2: Choose a stream connection to send the router request. + resetTimeoutTimer() + connectionCtxChoosingLoop: + for { + // Check if the dispatcher is canceled or the timeout timer is triggered. + select { + case <-ctx.Done(): + return + case <-timeoutTimer.C: + log.Error("[router] router stream connection is not ready until timeout, abort the batch") + c.svcDiscovery.ScheduleCheckMemberChanged() + c.batchController.FinishCollectedRequests(requestFinisher(nil), err) + continue batchLoop + default: + } + // Choose a stream connection to send the router request later. + connectionCtx := c.conCtxMgr.GetConnectionCtx() + if connectionCtx == nil { + log.Info("[router] router stream connection is not ready") + c.scheduleUpdateConnection() + continue connectionCtxChoosingLoop + } + streamCtx, streamURL, stream = connectionCtx.Ctx, connectionCtx.StreamURL, connectionCtx.Stream + // Check if the stream connection is canceled. + select { + case <-streamCtx.Done(): + log.Info("[router] router stream connection is canceled", zap.String("stream-url", streamURL)) + c.conCtxMgr.Release(streamURL) + continue connectionCtxChoosingLoop + default: + } + // The stream connection is ready, break the loop. + break connectionCtxChoosingLoop + } + + // Step 3: Dispatch the router requests to the stream connection. + // TODO: timeout handling if the stream takes too long to process the requests. + err = c.processRequests(stream) + if err != nil { + if !c.handleProcessRequestError(ctx, streamURL, err) { + return + } + } + } +} + +func (c *Cli) processRequests(stream pdpb.PD_QueryRegionClient) error { + var ( + requests = c.batchController.GetCollectedRequests() + traceRegions = make([]*trace.Region, 0, len(requests)) + spans = make([]opentracing.Span, 0, len(requests)) + ) + for _, req := range requests { + traceRegions = append(traceRegions, trace.StartRegion(req.requestCtx, "pdclient.regionReqSend")) + if span := opentracing.SpanFromContext(req.requestCtx); span != nil && span.Tracer() != nil { + spans = append(spans, span.Tracer().StartSpan("pdclient.processRegionRequests", opentracing.ChildOf(span.Context()))) + } + } + defer func() { + for i := range spans { + spans[i].Finish() + } + for i := range traceRegions { + traceRegions[i].End() + } + }() + + queryReq := &pdpb.QueryRegionRequest{ + Header: &pdpb.RequestHeader{ + ClusterId: c.svcDiscovery.GetClusterID(), + }, + Keys: make([][]byte, 0, len(requests)), + PrevKeys: make([][]byte, 0, len(requests)), + Ids: make([]uint64, 0, len(requests)), + } + for _, req := range requests { + if !queryReq.NeedBuckets && req.needBuckets { + queryReq.NeedBuckets = true + } + if req.key != nil { + queryReq.Keys = append(queryReq.Keys, req.key) + } else if req.prevKey != nil { + queryReq.PrevKeys = append(queryReq.PrevKeys, req.prevKey) + } else if req.id != 0 { + queryReq.Ids = append(queryReq.Ids, req.id) + } else { + panic("invalid region query request received") + } + } + err := stream.Send(queryReq) + if err != nil { + return err + } + resp, err := stream.Recv() + if err != nil { + return err + } + c.doneCollectedRequests(resp) + return nil +} + +func (c *Cli) handleProcessRequestError( + ctx context.Context, + streamURL string, + err error, +) bool { + log.Error("[router] failed to process the router requests", + zap.String("stream-url", streamURL), + errs.ZapError(err)) + c.cancelCollectedRequests(err) + + select { + case <-ctx.Done(): + return false + default: + } + + // Delete the stream connection context. + c.conCtxMgr.Release(streamURL) + if errs.IsLeaderChange(err) { + // If the leader changes, we better call `CheckMemberChanged` blockingly to + // ensure the next round of router requests can be sent to the new leader. + if err := c.bo.Exec(ctx, c.svcDiscovery.CheckMemberChanged); err != nil { + select { + case <-ctx.Done(): + return false + default: + } + } + } else { + // For other errors, we can just schedule a member change check asynchronously. + c.svcDiscovery.ScheduleCheckMemberChanged() + } + + return true +} diff --git a/client/clients/router/request.go b/client/clients/router/request.go new file mode 100644 index 00000000000..2e1c2e97aa5 --- /dev/null +++ b/client/clients/router/request.go @@ -0,0 +1,115 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package router + +import ( + "context" + "runtime/trace" + "sync" + + "github.com/pingcap/errors" + + "github.com/tikv/pd/client/opt" +) + +// Request is a region info request. +type Request struct { + requestCtx context.Context + clientCtx context.Context + + key []byte + prevKey []byte + id uint64 + needBuckets bool + + done chan error + // region will be set after the request is done. + region *Region + + // Runtime fields. + pool *sync.Pool +} + +func (req *Request) tryDone(err error) { + select { + case req.done <- err: + default: + } +} + +func (req *Request) wait() (*Region, error) { + // TODO: introduce the metrics. + select { + case err := <-req.done: + defer req.pool.Put(req) + defer trace.StartRegion(req.requestCtx, "pdclient.regionReqDone").End() + if err != nil { + return nil, errors.WithStack(err) + } + return req.region, nil + case <-req.requestCtx.Done(): + return nil, errors.WithStack(req.requestCtx.Err()) + case <-req.clientCtx.Done(): + return nil, errors.WithStack(req.clientCtx.Err()) + } +} + +// GetRegion implements the Client interface. +func (c *Cli) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) { + req := c.newRequest(ctx) + req.key = key + options := &opt.GetRegionOp{} + for _, opt := range opts { + opt(options) + } + req.needBuckets = options.NeedBuckets + + c.requestCh <- req + return req.wait() +} + +// GetRegionFromMember implements the Client interface. +func (c *Cli) GetRegionFromMember(ctx context.Context, key []byte, _ []string, opts ...opt.GetRegionOption) (*Region, error) { + // Before we support the follower stream connection, this method is equivalent to `GetRegion`. + return c.GetRegion(ctx, key, opts...) +} + +// GetPrevRegion implements the Client interface. +func (c *Cli) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) { + req := c.newRequest(ctx) + req.prevKey = key + options := &opt.GetRegionOp{} + for _, opt := range opts { + opt(options) + } + req.needBuckets = options.NeedBuckets + + c.requestCh <- req + return req.wait() +} + +// GetRegionByID implements the Client interface. +func (c *Cli) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt.GetRegionOption) (*Region, error) { + req := c.newRequest(ctx) + req.id = regionID + options := &opt.GetRegionOp{} + for _, opt := range opts { + opt(options) + } + req.needBuckets = options.NeedBuckets + + c.requestCh <- req + return req.wait() +} diff --git a/client/errs/errno.go b/client/errs/errno.go index 99a426d0776..8f81d2d6777 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -70,6 +70,7 @@ var ( ErrClientFindGroupByKeyspaceID = errors.Normalize("can't find keyspace group by keyspace id", errors.RFCCodeText("PD:client:ErrClientFindGroupByKeyspaceID")) ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream")) ErrCircuitBreakerOpen = errors.Normalize("circuit breaker is open", errors.RFCCodeText("PD:client:ErrCircuitBreakerOpen")) + ErrClientRegionNotFound = errors.Normalize("region %d not found", errors.RFCCodeText("PD:client:ErrClientRegionNotFound")) ) // grpcutil errors diff --git a/client/go.mod b/client/go.mod index 78aef084ff7..a84bf303be1 100644 --- a/client/go.mod +++ b/client/go.mod @@ -10,7 +10,7 @@ require ( github.com/opentracing/opentracing-go v1.2.0 github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 - github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 + github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 github.com/prometheus/client_golang v1.20.5 github.com/stretchr/testify v1.9.0 diff --git a/client/go.sum b/client/go.sum index 4cca5ba3ad5..2873e4f550c 100644 --- a/client/go.sum +++ b/client/go.sum @@ -49,8 +49,8 @@ github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c h1:xpW9bvK+HuuTm github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 h1:tdMsjOqUR7YXHoBitzdebTvOjs/swniBTOLy5XiMtuE= github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86/go.mod h1:exzhVYca3WRtd6gclGNErRWb1qEgff3LYta0LvRmON4= -github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 h1:xYNSJjYNur4Dr5bV+9BXK9n5E0T1zlcAN25XX68+mOg= -github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037/go.mod h1:rXxWk2UnwfUhLXha1jxRWPADw9eMZGWEWCg92Tgmb/8= +github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1 h1:rTAyiswGyWSGHJVa4Mkhdi8YfGqfA4LrUVKsH9nrJ8E= +github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1/go.mod h1:rXxWk2UnwfUhLXha1jxRWPADw9eMZGWEWCg92Tgmb/8= github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8IDP+SZrdhV1Kibl9KrHxJ9eciw= github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/client/inner_client.go b/client/inner_client.go index 7ce8f42386e..dbd7b62374e 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/tikv/pd/client/clients/router" "github.com/tikv/pd/client/clients/tso" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" @@ -45,6 +46,7 @@ type innerClient struct { serviceDiscovery sd.ServiceDiscovery tokenDispatcher *tokenDispatcher + routerClient *router.Cli // For service mode switching. serviceModeKeeper @@ -69,6 +71,7 @@ func (c *innerClient) init(updateKeyspaceIDCb sd.UpdateKeyspaceIDFunc) error { } return err } + c.routerClient = router.NewClient(c.ctx, c.serviceDiscovery, c.option) return nil } diff --git a/client/pkg/utils/timerutil/util.go b/client/pkg/utils/timerutil/util.go new file mode 100644 index 00000000000..9fee6c112b7 --- /dev/null +++ b/client/pkg/utils/timerutil/util.go @@ -0,0 +1,32 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package timerutil + +import "time" + +// SafeResetTimer is used to reset timer safely. +// Before Go 1.23, the only safe way to use Reset was to call Timer.Stop and explicitly drain the timer first. +// We need be careful here, see more details in the comments of Timer.Reset. +// https://pkg.go.dev/time@master#Timer.Reset +func SafeResetTimer(t *time.Timer, d time.Duration) { + // Stop the timer if it's not stopped. + if !t.Stop() { + select { + case <-t.C: // try to drain from the channel + default: + } + } + t.Reset(d) +} diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index af8cdc00a7e..91a6d44943e 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -1204,7 +1204,6 @@ func (suite *clientTestSuite) TestGetPrevRegion() { err := suite.regionHeartbeat.Send(req) re.NoError(err) } - time.Sleep(500 * time.Millisecond) for i := range 20 { testutil.Eventually(re, func() bool { r, err := suite.client.GetPrevRegion(context.Background(), []byte{byte(i)}) From a64c5fe56e6e7fb5c3b78d657d513761ea85e27a Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 22 Jan 2025 12:23:38 +0800 Subject: [PATCH 2/4] Setup the test case Signed-off-by: JmPotato --- client/client.go | 22 +++++++ client/clients/router/client.go | 5 +- client/clients/router/request.go | 6 -- client/errs/errno.go | 1 - client/inner_client.go | 8 ++- tests/integrations/client/client_test.go | 82 ++++++++++++++++++++++++ tests/integrations/go.mod | 2 +- 7 files changed, 113 insertions(+), 13 deletions(-) diff --git a/client/client.go b/client/client.go index f1c0a49ff8a..6e9c19b4c31 100644 --- a/client/client.go +++ b/client/client.go @@ -570,6 +570,16 @@ func (c *client) GetMinTS(ctx context.Context) (physical int64, logical int64, e return minTS.Physical, minTS.Logical, nil } +// EnableRouterClient enables the router client. +// This is only for test currently. +func (c *client) EnableRouterClient() { + c.inner.enableRouterClient.Store(true) +} + +func (c *client) isRouterClientEnabled() bool { + return c.inner.enableRouterClient.Load() +} + // GetRegionFromMember implements the RPCClient interface. func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs []string, _ ...opt.GetRegionOption) (*router.Region, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { @@ -620,6 +630,10 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if c.isRouterClientEnabled() { + return c.inner.routerClient.GetRegion(ctx, key, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) @@ -660,6 +674,10 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if c.isRouterClientEnabled() { + return c.inner.routerClient.GetPrevRegion(ctx, key, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) @@ -700,6 +718,10 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if c.isRouterClientEnabled() { + return c.inner.routerClient.GetRegionByID(ctx, regionID, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) diff --git a/client/clients/router/client.go b/client/clients/router/client.go index 8d0cfd64d2d..8bd44b8b6a3 100644 --- a/client/clients/router/client.go +++ b/client/clients/router/client.go @@ -227,10 +227,7 @@ func requestFinisher(resp *pdpb.QueryRegionResponse) batch.FinisherFunc[*Request } else if req.id != 0 { id = req.id } - region, ok := resp.RegionsById[id] - if !ok { - err = errs.ErrClientRegionNotFound.FastGenByArgs(id) - } else { + if region, ok := resp.RegionsById[id]; ok { req.region = ConvertToRegion(region) } req.tryDone(err) diff --git a/client/clients/router/request.go b/client/clients/router/request.go index 2e1c2e97aa5..cc1ada0a729 100644 --- a/client/clients/router/request.go +++ b/client/clients/router/request.go @@ -80,12 +80,6 @@ func (c *Cli) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOp return req.wait() } -// GetRegionFromMember implements the Client interface. -func (c *Cli) GetRegionFromMember(ctx context.Context, key []byte, _ []string, opts ...opt.GetRegionOption) (*Region, error) { - // Before we support the follower stream connection, this method is equivalent to `GetRegion`. - return c.GetRegion(ctx, key, opts...) -} - // GetPrevRegion implements the Client interface. func (c *Cli) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) { req := c.newRequest(ctx) diff --git a/client/errs/errno.go b/client/errs/errno.go index 8f81d2d6777..99a426d0776 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -70,7 +70,6 @@ var ( ErrClientFindGroupByKeyspaceID = errors.Normalize("can't find keyspace group by keyspace id", errors.RFCCodeText("PD:client:ErrClientFindGroupByKeyspaceID")) ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream")) ErrCircuitBreakerOpen = errors.Normalize("circuit breaker is open", errors.RFCCodeText("PD:client:ErrCircuitBreakerOpen")) - ErrClientRegionNotFound = errors.Normalize("region %d not found", errors.RFCCodeText("PD:client:ErrClientRegionNotFound")) ) // grpcutil errors diff --git a/client/inner_client.go b/client/inner_client.go index dbd7b62374e..7a716421af2 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "sync" + "sync/atomic" "time" "go.uber.org/zap" @@ -46,7 +47,12 @@ type innerClient struct { serviceDiscovery sd.ServiceDiscovery tokenDispatcher *tokenDispatcher - routerClient *router.Cli + // The router client is used to get the region info via the streaming gRPC, + // this flag is used to control whether to enable it, currently only used + // in the test. + enableRouterClient atomic.Bool + routerClient *router.Cli + // For service mode switching. serviceModeKeeper diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index 91a6d44943e..e8ee7ceb693 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -36,6 +36,7 @@ import ( "github.com/stretchr/testify/suite" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/goleak" + "golang.org/x/exp/rand" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -1105,6 +1106,10 @@ func bootstrapServer(re *require.Assertions, header *pdpb.RequestHeader, client re.Equal(pdpb.ErrorType_OK, resp.GetHeader().GetError().GetType()) } +func (suite *clientTestSuite) SetupTest() { + suite.grpcSvr.DirectlyGetRaftCluster().ResetRegionCache() +} + func (suite *clientTestSuite) TestGetRegion() { re := suite.Require() regionID := regionIDAllocator.alloc() @@ -1337,6 +1342,83 @@ func (suite *clientTestSuite) TestGetRegionByID() { }) } +func (suite *clientTestSuite) TestGetRegionConcurrently() { + suite.client.(interface{ EnableRouterClient() }).EnableRouterClient() + + re := suite.Require() + ctx, cancel := context.WithCancel(suite.ctx) + defer cancel() + + regions := make([]*metapb.Region, 0, 2) + for i := range 2 { + regionID := regionIDAllocator.alloc() + region := &metapb.Region{ + Id: regionID, + RegionEpoch: &metapb.RegionEpoch{ + ConfVer: 1, + Version: 1, + }, + StartKey: []byte{byte(i)}, + EndKey: []byte{byte(i + 1)}, + Peers: peers, + } + re.NoError(suite.regionHeartbeat.Send(&pdpb.RegionHeartbeatRequest{ + Header: newHeader(), + Region: region, + Leader: peers[0], + })) + regions = append(regions, region) + } + + const concurrency = 1000 + + wg := sync.WaitGroup{} + wg.Add(concurrency) + for range concurrency { + go func() { + defer wg.Done() + switch rand.Intn(3) { + case 0: + region := regions[0] + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegion(ctx, region.GetStartKey()) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + case 1: + testutil.Eventually(re, func() bool { + r, err := suite.client.GetPrevRegion(ctx, regions[1].GetStartKey()) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(regions[0], r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + case 2: + region := regions[0] + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegionByID(ctx, region.GetId()) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + } + }() + } + wg.Wait() +} + func (suite *clientTestSuite) TestGetStore() { re := suite.Require() cluster := suite.srv.GetRaftCluster() diff --git a/tests/integrations/go.mod b/tests/integrations/go.mod index fca5b54bb07..40c5350c18c 100644 --- a/tests/integrations/go.mod +++ b/tests/integrations/go.mod @@ -25,6 +25,7 @@ require ( go.etcd.io/etcd/client/v3 v3.5.15 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 + golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4 google.golang.org/grpc v1.62.1 gorm.io/driver/mysql v1.4.5 gorm.io/gorm v1.24.3 @@ -187,7 +188,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.24.0 // indirect - golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect From a84d793b574727c94f00741ce15d3b431ce8649c Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 22 Jan 2025 17:27:09 +0800 Subject: [PATCH 3/4] Put routerClient inside serviceModeKeeper Signed-off-by: JmPotato --- client/client.go | 23 ++++++++++++----------- client/clients/router/client.go | 3 ++- client/inner_client.go | 17 +++++++++-------- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/client/client.go b/client/client.go index 6e9c19b4c31..e5f6442780b 100644 --- a/client/client.go +++ b/client/client.go @@ -143,12 +143,11 @@ var _ Client = (*client)(nil) // serviceModeKeeper is for service mode switching. type serviceModeKeeper struct { - // RMutex here is for the future usage that there might be multiple goroutines - // triggering service mode switching concurrently. sync.RWMutex serviceMode pdpb.ServiceMode tsoClient *tso.Cli tsoSvcDiscovery sd.ServiceDiscovery + routerClient *router.Cli } func (k *serviceModeKeeper) close() { @@ -573,11 +572,13 @@ func (c *client) GetMinTS(ctx context.Context) (physical int64, logical int64, e // EnableRouterClient enables the router client. // This is only for test currently. func (c *client) EnableRouterClient() { - c.inner.enableRouterClient.Store(true) + c.inner.initRouterClient() } -func (c *client) isRouterClientEnabled() bool { - return c.inner.enableRouterClient.Load() +func (c *client) getRouterClient() *router.Cli { + c.inner.RLock() + defer c.inner.RUnlock() + return c.inner.routerClient } // GetRegionFromMember implements the RPCClient interface. @@ -630,8 +631,8 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() - if c.isRouterClientEnabled() { - return c.inner.routerClient.GetRegion(ctx, key, opts...) + if routerClient := c.getRouterClient(); routerClient != nil { + return routerClient.GetRegion(ctx, key, opts...) } options := &opt.GetRegionOp{} @@ -674,8 +675,8 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() - if c.isRouterClientEnabled() { - return c.inner.routerClient.GetPrevRegion(ctx, key, opts...) + if routerClient := c.getRouterClient(); routerClient != nil { + return routerClient.GetPrevRegion(ctx, key, opts...) } options := &opt.GetRegionOp{} @@ -718,8 +719,8 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() - if c.isRouterClientEnabled() { - return c.inner.routerClient.GetRegionByID(ctx, regionID, opts...) + if routerClient := c.getRouterClient(); routerClient != nil { + return routerClient.GetRegionByID(ctx, regionID, opts...) } options := &opt.GetRegionOp{} diff --git a/client/clients/router/client.go b/client/clients/router/client.go index 8bd44b8b6a3..240e403d036 100644 --- a/client/clients/router/client.go +++ b/client/clients/router/client.go @@ -265,9 +265,10 @@ func (c *Cli) getLeaderURL() string { func (c *Cli) updateLeaderURL(url string) error { oldURL := c.getLeaderURL() - if !c.leaderURL.CompareAndSwap(oldURL, url) { + if oldURL == url { return nil } + c.leaderURL.Store(url) c.scheduleUpdateConnection() log.Info("[router] switch the router leader serving url", diff --git a/client/inner_client.go b/client/inner_client.go index 7a716421af2..181ee2c9d52 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -18,7 +18,6 @@ import ( "context" "crypto/tls" "sync" - "sync/atomic" "time" "go.uber.org/zap" @@ -47,12 +46,6 @@ type innerClient struct { serviceDiscovery sd.ServiceDiscovery tokenDispatcher *tokenDispatcher - // The router client is used to get the region info via the streaming gRPC, - // this flag is used to control whether to enable it, currently only used - // in the test. - enableRouterClient atomic.Bool - routerClient *router.Cli - // For service mode switching. serviceModeKeeper @@ -77,11 +70,19 @@ func (c *innerClient) init(updateKeyspaceIDCb sd.UpdateKeyspaceIDFunc) error { } return err } - c.routerClient = router.NewClient(c.ctx, c.serviceDiscovery, c.option) return nil } +func (c *innerClient) initRouterClient() { + c.Lock() + defer c.Unlock() + if c.routerClient != nil { + return + } + c.routerClient = router.NewClient(c.ctx, c.serviceDiscovery, c.option) +} + func (c *innerClient) setServiceMode(newMode pdpb.ServiceMode) { c.Lock() defer c.Unlock() From 7d6da813b3b64488a2405cd26515127a71dcb046 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Fri, 24 Jan 2025 16:47:29 +0800 Subject: [PATCH 4/4] Address the comments Signed-off-by: JmPotato --- client/clients/router/client.go | 11 +++++--- client/clients/router/request.go | 11 +++++--- client/pkg/connectionctx/manager.go | 3 +-- client/pkg/utils/timerutil/util.go | 32 ------------------------ tests/integrations/client/client_test.go | 2 +- tests/integrations/go.mod | 2 +- 6 files changed, 18 insertions(+), 43 deletions(-) delete mode 100644 client/pkg/utils/timerutil/util.go diff --git a/client/clients/router/client.go b/client/clients/router/client.go index 240e403d036..0038c42dff8 100644 --- a/client/clients/router/client.go +++ b/client/clients/router/client.go @@ -36,7 +36,6 @@ import ( "github.com/tikv/pd/client/pkg/batch" cctx "github.com/tikv/pd/client/pkg/connectionctx" "github.com/tikv/pd/client/pkg/retry" - "github.com/tikv/pd/client/pkg/utils/timerutil" sd "github.com/tikv/pd/client/servicediscovery" ) @@ -176,7 +175,10 @@ func NewClient( conCtxMgr: cctx.NewManager[pdpb.PD_QueryRegionClient](), updateConnectionCh: make(chan struct{}, 1), bo: retry.InitialBackoffer( - sd.UpdateMemberBackOffBaseTime, sd.UpdateMemberTimeout, sd.UpdateMemberBackOffBaseTime), + sd.UpdateMemberBackOffBaseTime, + sd.UpdateMemberMaxBackoffTime, + sd.UpdateMemberTimeout, + ), reqPool: &sync.Pool{ New: func() any { return &Request{ @@ -189,6 +191,7 @@ func NewClient( } c.leaderURL.Store(svcDiscovery.GetServingURL()) c.svcDiscovery.ExecAndAddLeaderSwitchedCallback(c.updateLeaderURL) + c.svcDiscovery.AddMembersChangedCallback(c.scheduleUpdateConnection) c.wg.Add(2) go c.connectionDaemon() @@ -353,7 +356,7 @@ func (c *Cli) dispatcher() { if timeoutTimer == nil { timeoutTimer = time.NewTimer(c.option.Timeout) } else { - timerutil.SafeResetTimer(timeoutTimer, c.option.Timeout) + timeoutTimer.Reset(c.option.Timeout) } } ctx, cancel = context.WithCancel(c.ctx) @@ -406,7 +409,7 @@ batchLoop: connectionCtx := c.conCtxMgr.GetConnectionCtx() if connectionCtx == nil { log.Info("[router] router stream connection is not ready") - c.scheduleUpdateConnection() + c.updateConnection(ctx) continue connectionCtxChoosingLoop } streamCtx, streamURL, stream = connectionCtx.Ctx, connectionCtx.StreamURL, connectionCtx.Stream diff --git a/client/clients/router/request.go b/client/clients/router/request.go index cc1ada0a729..4578514597d 100644 --- a/client/clients/router/request.go +++ b/client/clients/router/request.go @@ -29,9 +29,14 @@ type Request struct { requestCtx context.Context clientCtx context.Context - key []byte - prevKey []byte - id uint64 + // Key field represents this is a `GetRegion` request. + key []byte + // PrevKey field represents this is a `GetPrevRegion` request. + prevKey []byte + // ID field represents this is a `GetRegionByID` request. + id uint64 + + // NeedBuckets field represents whether the request needs to get the region buckets. needBuckets bool done chan error diff --git a/client/pkg/connectionctx/manager.go b/client/pkg/connectionctx/manager.go index 04c1eb13d3a..fede8baf723 100644 --- a/client/pkg/connectionctx/manager.go +++ b/client/pkg/connectionctx/manager.go @@ -16,9 +16,8 @@ package connectionctx import ( "context" + "math/rand" "sync" - - "golang.org/x/exp/rand" ) type connectionCtx[T any] struct { diff --git a/client/pkg/utils/timerutil/util.go b/client/pkg/utils/timerutil/util.go deleted file mode 100644 index 9fee6c112b7..00000000000 --- a/client/pkg/utils/timerutil/util.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2024 TiKV Project Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package timerutil - -import "time" - -// SafeResetTimer is used to reset timer safely. -// Before Go 1.23, the only safe way to use Reset was to call Timer.Stop and explicitly drain the timer first. -// We need be careful here, see more details in the comments of Timer.Reset. -// https://pkg.go.dev/time@master#Timer.Reset -func SafeResetTimer(t *time.Timer, d time.Duration) { - // Stop the timer if it's not stopped. - if !t.Stop() { - select { - case <-t.C: // try to drain from the channel - default: - } - } - t.Reset(d) -} diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index e8ee7ceb693..0f972cb2b8f 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" "math" + "math/rand" "os" "path" "reflect" @@ -36,7 +37,6 @@ import ( "github.com/stretchr/testify/suite" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/goleak" - "golang.org/x/exp/rand" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" diff --git a/tests/integrations/go.mod b/tests/integrations/go.mod index 40c5350c18c..fca5b54bb07 100644 --- a/tests/integrations/go.mod +++ b/tests/integrations/go.mod @@ -25,7 +25,6 @@ require ( go.etcd.io/etcd/client/v3 v3.5.15 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 - golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4 google.golang.org/grpc v1.62.1 gorm.io/driver/mysql v1.4.5 gorm.io/gorm v1.24.3 @@ -188,6 +187,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.24.0 // indirect + golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect