diff --git a/client/client.go b/client/client.go index fa0a1473ba7..e5f6442780b 100644 --- a/client/client.go +++ b/client/client.go @@ -143,12 +143,11 @@ var _ Client = (*client)(nil) // serviceModeKeeper is for service mode switching. type serviceModeKeeper struct { - // RMutex here is for the future usage that there might be multiple goroutines - // triggering service mode switching concurrently. sync.RWMutex serviceMode pdpb.ServiceMode tsoClient *tso.Cli tsoSvcDiscovery sd.ServiceDiscovery + routerClient *router.Cli } func (k *serviceModeKeeper) close() { @@ -570,21 +569,16 @@ func (c *client) GetMinTS(ctx context.Context) (physical int64, logical int64, e return minTS.Physical, minTS.Logical, nil } -func handleRegionResponse(res *pdpb.GetRegionResponse) *router.Region { - if res.Region == nil { - return nil - } +// EnableRouterClient enables the router client. +// This is only for test currently. +func (c *client) EnableRouterClient() { + c.inner.initRouterClient() +} - r := &router.Region{ - Meta: res.Region, - Leader: res.Leader, - PendingPeers: res.PendingPeers, - Buckets: res.Buckets, - } - for _, s := range res.DownPeers { - r.DownPeers = append(r.DownPeers, s.Peer) - } - return r +func (c *client) getRouterClient() *router.Cli { + c.inner.RLock() + defer c.inner.RUnlock() + return c.inner.routerClient } // GetRegionFromMember implements the RPCClient interface. @@ -623,7 +617,7 @@ func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs errorMsg := fmt.Sprintf("[pd] can't get region info from member URLs: %+v", memberURLs) return nil, errors.WithStack(errors.New(errorMsg)) } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // GetRegion implements the RPCClient interface. @@ -637,6 +631,10 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if routerClient := c.getRouterClient(); routerClient != nil { + return routerClient.GetRegion(ctx, key, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) @@ -663,7 +661,7 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio if err = c.respForErr(metrics.CmdFailedDurationGetRegion, start, err, resp.GetHeader()); err != nil { return nil, err } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // GetPrevRegion implements the RPCClient interface. @@ -677,6 +675,10 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if routerClient := c.getRouterClient(); routerClient != nil { + return routerClient.GetPrevRegion(ctx, key, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) @@ -703,7 +705,7 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR if err = c.respForErr(metrics.CmdFailedDurationGetPrevRegion, start, err, resp.GetHeader()); err != nil { return nil, err } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // GetRegionByID implements the RPCClient interface. @@ -717,6 +719,10 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if routerClient := c.getRouterClient(); routerClient != nil { + return routerClient.GetRegionByID(ctx, regionID, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) @@ -744,7 +750,7 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt if err = c.respForErr(metrics.CmdFailedDurationGetRegionByID, start, err, resp.GetHeader()); err != nil { return nil, err } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // ScanRegions implements the RPCClient interface. diff --git a/client/clients/router/client.go b/client/clients/router/client.go index 48cebfa950e..240e403d036 100644 --- a/client/clients/router/client.go +++ b/client/clients/router/client.go @@ -18,12 +18,31 @@ import ( "context" "encoding/hex" "net/url" + "runtime/trace" + "sync" + "sync/atomic" + "time" + + "github.com/opentracing/opentracing-go" + "go.uber.org/zap" + "google.golang.org/grpc" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/opt" + "github.com/tikv/pd/client/pkg/batch" + cctx "github.com/tikv/pd/client/pkg/connectionctx" + "github.com/tikv/pd/client/pkg/retry" + "github.com/tikv/pd/client/pkg/utils/timerutil" + sd "github.com/tikv/pd/client/servicediscovery" ) +// defaultMaxRouterRequestBatchSize is the default max size of the router request batch. +const defaultMaxRouterRequestBatchSize = 10000 + // Region contains information of a region's meta and its peers. type Region struct { Meta *metapb.Region @@ -33,6 +52,33 @@ type Region struct { Buckets *metapb.Buckets } +type regionResponse interface { + GetRegion() *metapb.Region + GetLeader() *metapb.Peer + GetDownPeers() []*pdpb.PeerStats + GetPendingPeers() []*metapb.Peer + GetBuckets() *metapb.Buckets +} + +// ConvertToRegion converts the region response to the region. +func ConvertToRegion(res regionResponse) *Region { + region := res.GetRegion() + if region == nil { + return nil + } + + r := &Region{ + Meta: region, + Leader: res.GetLeader(), + PendingPeers: res.GetPendingPeers(), + Buckets: res.GetBuckets(), + } + for _, s := range res.GetDownPeers() { + r.DownPeers = append(r.DownPeers, s.Peer) + } + return r +} + // KeyRange defines a range of keys in bytes. type KeyRange struct { StartKey []byte @@ -92,3 +138,388 @@ type Client interface { // The returned regions are flattened, even there are key ranges located in the same region, only one region will be returned. BatchScanRegions(ctx context.Context, keyRanges []KeyRange, limit int, opts ...opt.GetRegionOption) ([]*Region, error) } + +// Cli is the implementation of the router client. +type Cli struct { + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + option *opt.Option + + svcDiscovery sd.ServiceDiscovery + // leaderURL is the URL of the router leader. + leaderURL atomic.Value + // conCtxMgr is used to store the context of the router stream connection(s). + conCtxMgr *cctx.Manager[pdpb.PD_QueryRegionClient] + // updateConnectionCh is used to trigger the connection update actively. + updateConnectionCh chan struct{} + // bo is the backoffer for the router client. + bo *retry.Backoffer + + reqPool *sync.Pool + requestCh chan *Request + batchController *batch.Controller[*Request] +} + +// NewClient returns a new router client. +func NewClient( + ctx context.Context, + svcDiscovery sd.ServiceDiscovery, + option *opt.Option, +) *Cli { + ctx, cancel := context.WithCancel(ctx) + c := &Cli{ + ctx: ctx, + cancel: cancel, + svcDiscovery: svcDiscovery, + option: option, + conCtxMgr: cctx.NewManager[pdpb.PD_QueryRegionClient](), + updateConnectionCh: make(chan struct{}, 1), + bo: retry.InitialBackoffer( + sd.UpdateMemberBackOffBaseTime, sd.UpdateMemberTimeout, sd.UpdateMemberBackOffBaseTime), + reqPool: &sync.Pool{ + New: func() any { + return &Request{ + done: make(chan error, 1), + } + }, + }, + requestCh: make(chan *Request, defaultMaxRouterRequestBatchSize*2), + batchController: batch.NewController(defaultMaxRouterRequestBatchSize, requestFinisher(nil), nil), + } + c.leaderURL.Store(svcDiscovery.GetServingURL()) + c.svcDiscovery.ExecAndAddLeaderSwitchedCallback(c.updateLeaderURL) + + c.wg.Add(2) + go c.connectionDaemon() + go c.dispatcher() + + return c +} + +func (c *Cli) newRequest(ctx context.Context) *Request { + req := c.reqPool.Get().(*Request) + req.requestCtx = ctx + req.clientCtx = c.ctx + req.pool = c.reqPool + + return req +} + +func requestFinisher(resp *pdpb.QueryRegionResponse) batch.FinisherFunc[*Request] { + var keyIdx, prevKeyIdx int + return func(_ int, req *Request, err error) { + requestCtx := req.requestCtx + defer trace.StartRegion(requestCtx, "pdclient.regionReqDone").End() + + if err != nil { + req.tryDone(err) + return + } + + var id uint64 + if req.key != nil { + id = resp.KeyIdMap[keyIdx] + keyIdx++ + } else if req.prevKey != nil { + id = resp.PrevKeyIdMap[prevKeyIdx] + prevKeyIdx++ + } else if req.id != 0 { + id = req.id + } + if region, ok := resp.RegionsById[id]; ok { + req.region = ConvertToRegion(region) + } + req.tryDone(err) + } +} + +func (c *Cli) cancelCollectedRequests(err error) { + c.batchController.FinishCollectedRequests(requestFinisher(nil), err) +} + +func (c *Cli) doneCollectedRequests(resp *pdpb.QueryRegionResponse) { + c.batchController.FinishCollectedRequests(requestFinisher(resp), nil) +} + +// Close closes the router client. +func (c *Cli) Close() { + if c == nil { + return + } + log.Info("[router] closing router client") + + c.cancel() + c.wg.Wait() + + log.Info("[router] router client is closed") +} + +func (c *Cli) getLeaderURL() string { + url := c.leaderURL.Load() + if url == nil { + return "" + } + return url.(string) +} + +func (c *Cli) updateLeaderURL(url string) error { + oldURL := c.getLeaderURL() + if oldURL == url { + return nil + } + c.leaderURL.Store(url) + c.scheduleUpdateConnection() + + log.Info("[router] switch the router leader serving url", + zap.String("old-url", oldURL), zap.String("new-url", url)) + return nil +} + +// getLeaderClientConn returns the leader gRPC client connection. +func (c *Cli) getLeaderClientConn() (*grpc.ClientConn, string) { + url := c.getLeaderURL() + if len(url) == 0 { + c.svcDiscovery.ScheduleCheckMemberChanged() + return nil, "" + } + cc, ok := c.svcDiscovery.GetClientConns().Load(url) + if !ok { + return nil, url + } + return cc.(*grpc.ClientConn), url +} + +// scheduleUpdateConnection is used to schedule an update to the connection(s). +func (c *Cli) scheduleUpdateConnection() { + select { + case c.updateConnectionCh <- struct{}{}: + default: + } +} + +// connectionDaemon is used to update the router leader/primary/backup connection(s) in background. +// It aims to provide a seamless connection updating for the router client to keep providing the +// router service without interruption. +func (c *Cli) connectionDaemon() { + defer c.wg.Done() + updaterCtx, updaterCancel := context.WithCancel(c.ctx) + defer updaterCancel() + updateTicker := time.NewTicker(sd.MemberUpdateInterval) + defer updateTicker.Stop() + + log.Info("[router] connection daemon is started") + for { + c.updateConnection(updaterCtx) + select { + case <-updaterCtx.Done(): + log.Info("[router] connection daemon is exiting") + return + case <-updateTicker.C: + case <-c.updateConnectionCh: + } + } +} + +// updateConnection is used to get the leader client connection and update the connection context if it does not exist before. +func (c *Cli) updateConnection(ctx context.Context) { + cc, url := c.getLeaderClientConn() + if cc == nil || len(url) == 0 { + log.Warn("[router] got an invalid leader client connection", zap.String("url", url)) + return + } + if c.conCtxMgr.Exist(url) { + log.Debug("[router] the router leader remains unchanged", zap.String("url", url)) + return + } + stream, err := pdpb.NewPDClient(cc).QueryRegion(ctx) + if err != nil { + log.Error("[router] failed to create the router stream connection", errs.ZapError(err)) + } + c.conCtxMgr.Store(ctx, url, stream) + // TODO: support the forwarding mechanism for the router client. + // TODO: support sending the router requests to the follower nodes. +} + +func (c *Cli) dispatcher() { + defer c.wg.Done() + + var ( + stream pdpb.PD_QueryRegionClient + streamURL string + streamCtx context.Context + timeoutTimer *time.Timer + resetTimeoutTimer = func() { + if timeoutTimer == nil { + timeoutTimer = time.NewTimer(c.option.Timeout) + } else { + timerutil.SafeResetTimer(timeoutTimer, c.option.Timeout) + } + } + ctx, cancel = context.WithCancel(c.ctx) + ) + + log.Info("[router] dispatcher is started") + defer func() { + log.Info("[router] dispatcher is exiting") + cancel() + if timeoutTimer != nil { + timeoutTimer.Stop() + } + log.Info("[router] dispatcher exited") + }() +batchLoop: + for { + select { + case <-ctx.Done(): + return + default: + } + + // Step 1: Fetch the pending router requests in batch. + err := c.batchController.FetchPendingRequests(ctx, c.requestCh, nil, 0) + if err != nil { + if err == context.Canceled { + log.Info("[router] stop fetching the pending router requests due to context canceled") + } else { + log.Error("[router] failed to fetch the pending router requests", errs.ZapError(err)) + } + return + } + + // Step 2: Choose a stream connection to send the router request. + resetTimeoutTimer() + connectionCtxChoosingLoop: + for { + // Check if the dispatcher is canceled or the timeout timer is triggered. + select { + case <-ctx.Done(): + return + case <-timeoutTimer.C: + log.Error("[router] router stream connection is not ready until timeout, abort the batch") + c.svcDiscovery.ScheduleCheckMemberChanged() + c.batchController.FinishCollectedRequests(requestFinisher(nil), err) + continue batchLoop + default: + } + // Choose a stream connection to send the router request later. + connectionCtx := c.conCtxMgr.GetConnectionCtx() + if connectionCtx == nil { + log.Info("[router] router stream connection is not ready") + c.scheduleUpdateConnection() + continue connectionCtxChoosingLoop + } + streamCtx, streamURL, stream = connectionCtx.Ctx, connectionCtx.StreamURL, connectionCtx.Stream + // Check if the stream connection is canceled. + select { + case <-streamCtx.Done(): + log.Info("[router] router stream connection is canceled", zap.String("stream-url", streamURL)) + c.conCtxMgr.Release(streamURL) + continue connectionCtxChoosingLoop + default: + } + // The stream connection is ready, break the loop. + break connectionCtxChoosingLoop + } + + // Step 3: Dispatch the router requests to the stream connection. + // TODO: timeout handling if the stream takes too long to process the requests. + err = c.processRequests(stream) + if err != nil { + if !c.handleProcessRequestError(ctx, streamURL, err) { + return + } + } + } +} + +func (c *Cli) processRequests(stream pdpb.PD_QueryRegionClient) error { + var ( + requests = c.batchController.GetCollectedRequests() + traceRegions = make([]*trace.Region, 0, len(requests)) + spans = make([]opentracing.Span, 0, len(requests)) + ) + for _, req := range requests { + traceRegions = append(traceRegions, trace.StartRegion(req.requestCtx, "pdclient.regionReqSend")) + if span := opentracing.SpanFromContext(req.requestCtx); span != nil && span.Tracer() != nil { + spans = append(spans, span.Tracer().StartSpan("pdclient.processRegionRequests", opentracing.ChildOf(span.Context()))) + } + } + defer func() { + for i := range spans { + spans[i].Finish() + } + for i := range traceRegions { + traceRegions[i].End() + } + }() + + queryReq := &pdpb.QueryRegionRequest{ + Header: &pdpb.RequestHeader{ + ClusterId: c.svcDiscovery.GetClusterID(), + }, + Keys: make([][]byte, 0, len(requests)), + PrevKeys: make([][]byte, 0, len(requests)), + Ids: make([]uint64, 0, len(requests)), + } + for _, req := range requests { + if !queryReq.NeedBuckets && req.needBuckets { + queryReq.NeedBuckets = true + } + if req.key != nil { + queryReq.Keys = append(queryReq.Keys, req.key) + } else if req.prevKey != nil { + queryReq.PrevKeys = append(queryReq.PrevKeys, req.prevKey) + } else if req.id != 0 { + queryReq.Ids = append(queryReq.Ids, req.id) + } else { + panic("invalid region query request received") + } + } + err := stream.Send(queryReq) + if err != nil { + return err + } + resp, err := stream.Recv() + if err != nil { + return err + } + c.doneCollectedRequests(resp) + return nil +} + +func (c *Cli) handleProcessRequestError( + ctx context.Context, + streamURL string, + err error, +) bool { + log.Error("[router] failed to process the router requests", + zap.String("stream-url", streamURL), + errs.ZapError(err)) + c.cancelCollectedRequests(err) + + select { + case <-ctx.Done(): + return false + default: + } + + // Delete the stream connection context. + c.conCtxMgr.Release(streamURL) + if errs.IsLeaderChange(err) { + // If the leader changes, we better call `CheckMemberChanged` blockingly to + // ensure the next round of router requests can be sent to the new leader. + if err := c.bo.Exec(ctx, c.svcDiscovery.CheckMemberChanged); err != nil { + select { + case <-ctx.Done(): + return false + default: + } + } + } else { + // For other errors, we can just schedule a member change check asynchronously. + c.svcDiscovery.ScheduleCheckMemberChanged() + } + + return true +} diff --git a/client/clients/router/request.go b/client/clients/router/request.go new file mode 100644 index 00000000000..cc1ada0a729 --- /dev/null +++ b/client/clients/router/request.go @@ -0,0 +1,109 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package router + +import ( + "context" + "runtime/trace" + "sync" + + "github.com/pingcap/errors" + + "github.com/tikv/pd/client/opt" +) + +// Request is a region info request. +type Request struct { + requestCtx context.Context + clientCtx context.Context + + key []byte + prevKey []byte + id uint64 + needBuckets bool + + done chan error + // region will be set after the request is done. + region *Region + + // Runtime fields. + pool *sync.Pool +} + +func (req *Request) tryDone(err error) { + select { + case req.done <- err: + default: + } +} + +func (req *Request) wait() (*Region, error) { + // TODO: introduce the metrics. + select { + case err := <-req.done: + defer req.pool.Put(req) + defer trace.StartRegion(req.requestCtx, "pdclient.regionReqDone").End() + if err != nil { + return nil, errors.WithStack(err) + } + return req.region, nil + case <-req.requestCtx.Done(): + return nil, errors.WithStack(req.requestCtx.Err()) + case <-req.clientCtx.Done(): + return nil, errors.WithStack(req.clientCtx.Err()) + } +} + +// GetRegion implements the Client interface. +func (c *Cli) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) { + req := c.newRequest(ctx) + req.key = key + options := &opt.GetRegionOp{} + for _, opt := range opts { + opt(options) + } + req.needBuckets = options.NeedBuckets + + c.requestCh <- req + return req.wait() +} + +// GetPrevRegion implements the Client interface. +func (c *Cli) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) { + req := c.newRequest(ctx) + req.prevKey = key + options := &opt.GetRegionOp{} + for _, opt := range opts { + opt(options) + } + req.needBuckets = options.NeedBuckets + + c.requestCh <- req + return req.wait() +} + +// GetRegionByID implements the Client interface. +func (c *Cli) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt.GetRegionOption) (*Region, error) { + req := c.newRequest(ctx) + req.id = regionID + options := &opt.GetRegionOp{} + for _, opt := range opts { + opt(options) + } + req.needBuckets = options.NeedBuckets + + c.requestCh <- req + return req.wait() +} diff --git a/client/go.mod b/client/go.mod index 78aef084ff7..a84bf303be1 100644 --- a/client/go.mod +++ b/client/go.mod @@ -10,7 +10,7 @@ require ( github.com/opentracing/opentracing-go v1.2.0 github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 - github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 + github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 github.com/prometheus/client_golang v1.20.5 github.com/stretchr/testify v1.9.0 diff --git a/client/go.sum b/client/go.sum index 4cca5ba3ad5..2873e4f550c 100644 --- a/client/go.sum +++ b/client/go.sum @@ -49,8 +49,8 @@ github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c h1:xpW9bvK+HuuTm github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 h1:tdMsjOqUR7YXHoBitzdebTvOjs/swniBTOLy5XiMtuE= github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86/go.mod h1:exzhVYca3WRtd6gclGNErRWb1qEgff3LYta0LvRmON4= -github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 h1:xYNSJjYNur4Dr5bV+9BXK9n5E0T1zlcAN25XX68+mOg= -github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037/go.mod h1:rXxWk2UnwfUhLXha1jxRWPADw9eMZGWEWCg92Tgmb/8= +github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1 h1:rTAyiswGyWSGHJVa4Mkhdi8YfGqfA4LrUVKsH9nrJ8E= +github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1/go.mod h1:rXxWk2UnwfUhLXha1jxRWPADw9eMZGWEWCg92Tgmb/8= github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8IDP+SZrdhV1Kibl9KrHxJ9eciw= github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/client/inner_client.go b/client/inner_client.go index 7ce8f42386e..181ee2c9d52 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/tikv/pd/client/clients/router" "github.com/tikv/pd/client/clients/tso" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" @@ -73,6 +74,15 @@ func (c *innerClient) init(updateKeyspaceIDCb sd.UpdateKeyspaceIDFunc) error { return nil } +func (c *innerClient) initRouterClient() { + c.Lock() + defer c.Unlock() + if c.routerClient != nil { + return + } + c.routerClient = router.NewClient(c.ctx, c.serviceDiscovery, c.option) +} + func (c *innerClient) setServiceMode(newMode pdpb.ServiceMode) { c.Lock() defer c.Unlock() diff --git a/client/pkg/utils/timerutil/util.go b/client/pkg/utils/timerutil/util.go new file mode 100644 index 00000000000..9fee6c112b7 --- /dev/null +++ b/client/pkg/utils/timerutil/util.go @@ -0,0 +1,32 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package timerutil + +import "time" + +// SafeResetTimer is used to reset timer safely. +// Before Go 1.23, the only safe way to use Reset was to call Timer.Stop and explicitly drain the timer first. +// We need be careful here, see more details in the comments of Timer.Reset. +// https://pkg.go.dev/time@master#Timer.Reset +func SafeResetTimer(t *time.Timer, d time.Duration) { + // Stop the timer if it's not stopped. + if !t.Stop() { + select { + case <-t.C: // try to drain from the channel + default: + } + } + t.Reset(d) +} diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index af8cdc00a7e..e8ee7ceb693 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -36,6 +36,7 @@ import ( "github.com/stretchr/testify/suite" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/goleak" + "golang.org/x/exp/rand" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -1105,6 +1106,10 @@ func bootstrapServer(re *require.Assertions, header *pdpb.RequestHeader, client re.Equal(pdpb.ErrorType_OK, resp.GetHeader().GetError().GetType()) } +func (suite *clientTestSuite) SetupTest() { + suite.grpcSvr.DirectlyGetRaftCluster().ResetRegionCache() +} + func (suite *clientTestSuite) TestGetRegion() { re := suite.Require() regionID := regionIDAllocator.alloc() @@ -1204,7 +1209,6 @@ func (suite *clientTestSuite) TestGetPrevRegion() { err := suite.regionHeartbeat.Send(req) re.NoError(err) } - time.Sleep(500 * time.Millisecond) for i := range 20 { testutil.Eventually(re, func() bool { r, err := suite.client.GetPrevRegion(context.Background(), []byte{byte(i)}) @@ -1338,6 +1342,83 @@ func (suite *clientTestSuite) TestGetRegionByID() { }) } +func (suite *clientTestSuite) TestGetRegionConcurrently() { + suite.client.(interface{ EnableRouterClient() }).EnableRouterClient() + + re := suite.Require() + ctx, cancel := context.WithCancel(suite.ctx) + defer cancel() + + regions := make([]*metapb.Region, 0, 2) + for i := range 2 { + regionID := regionIDAllocator.alloc() + region := &metapb.Region{ + Id: regionID, + RegionEpoch: &metapb.RegionEpoch{ + ConfVer: 1, + Version: 1, + }, + StartKey: []byte{byte(i)}, + EndKey: []byte{byte(i + 1)}, + Peers: peers, + } + re.NoError(suite.regionHeartbeat.Send(&pdpb.RegionHeartbeatRequest{ + Header: newHeader(), + Region: region, + Leader: peers[0], + })) + regions = append(regions, region) + } + + const concurrency = 1000 + + wg := sync.WaitGroup{} + wg.Add(concurrency) + for range concurrency { + go func() { + defer wg.Done() + switch rand.Intn(3) { + case 0: + region := regions[0] + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegion(ctx, region.GetStartKey()) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + case 1: + testutil.Eventually(re, func() bool { + r, err := suite.client.GetPrevRegion(ctx, regions[1].GetStartKey()) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(regions[0], r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + case 2: + region := regions[0] + testutil.Eventually(re, func() bool { + r, err := suite.client.GetRegionByID(ctx, region.GetId()) + re.NoError(err) + if r == nil { + return false + } + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil + }) + } + }() + } + wg.Wait() +} + func (suite *clientTestSuite) TestGetStore() { re := suite.Require() cluster := suite.srv.GetRaftCluster() diff --git a/tests/integrations/go.mod b/tests/integrations/go.mod index fca5b54bb07..40c5350c18c 100644 --- a/tests/integrations/go.mod +++ b/tests/integrations/go.mod @@ -25,6 +25,7 @@ require ( go.etcd.io/etcd/client/v3 v3.5.15 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 + golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4 google.golang.org/grpc v1.62.1 gorm.io/driver/mysql v1.4.5 gorm.io/gorm v1.24.3 @@ -187,7 +188,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.24.0 // indirect - golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect