From 897f5cecea36c1417f43bec3ef1f61861d15237e Mon Sep 17 00:00:00 2001 From: Arsene Date: Wed, 18 Dec 2024 19:48:50 +0000 Subject: [PATCH] refactor: refactor the implementation (#37) --- config.go | 15 +++++++++++-- internal/tcp/tcp.go | 43 +++++++++++++++++++++++++++++++++++- node.go | 53 +++++++++++++++++++++++++++++++++++---------- node_test.go | 2 ++ 4 files changed, 98 insertions(+), 15 deletions(-) diff --git a/config.go b/config.go index 540d1ae..7fad68e 100644 --- a/config.go +++ b/config.go @@ -40,6 +40,7 @@ type Config struct { maxJoinAttempts int // specifies the join retry interval joinRetryInterval time.Duration + joinTimeout time.Duration // specifies the discovery provider provider discovery.Provider // specifies the node client port @@ -89,9 +90,10 @@ func NewConfig() *Config { maxJoinAttempts: 5, joinRetryInterval: time.Second, shutdownTimeout: 3 * time.Second, + joinTimeout: time.Minute, syncInterval: time.Minute, logger: log.New(log.ErrorLevel, os.Stderr), - readTimeout: time.Second, + readTimeout: time.Minute, } } @@ -143,7 +145,12 @@ func (config *Config) WithHost(host string) *Config { return config } -// WithSyncInterval sets the delegate sync interval +// WithSyncInterval sets the cluster synchronization interval. +// This is the interval between complete states synchronization between nodes. +// Complete states synchronization are done with a single node over TCP and are +// quite expensive relative to standard gossiped messages. +// Setting this interval lower (more frequent) will increase convergence +// speeds across larger clusters at the expense of increased bandwidth usage. func (config *Config) WithSyncInterval(interval time.Duration) *Config { config.syncInterval = interval return config @@ -151,6 +158,7 @@ func (config *Config) WithSyncInterval(interval time.Duration) *Config { // WithReadTimeout sets the Node read timeout. // This timeout specifies the timeout of a data retrieval +// The read timeout should be either greater or equal to syncInterval func (config *Config) WithReadTimeout(timeout time.Duration) *Config { config.readTimeout = timeout return config @@ -186,8 +194,11 @@ func (config *Config) Validate() error { AddAssertion(config.provider != nil, "discovery provider is not set"). AddAssertion(config.joinRetryInterval > 0, "join retry interval is invalid"). AddAssertion(config.shutdownTimeout > 0, "shutdown timeout is invalid"). + AddAssertion(config.joinTimeout > 0, "join timeout is invalid"). AddAssertion(config.maxJoinAttempts > 0, "max join attempts is invalid"). AddAssertion(config.syncInterval > 0, "stateSync interval is invalid"). + AddAssertion(config.readTimeout > 0, "read timeout is invalid"). + AddAssertion(config.joinTimeout > config.joinRetryInterval, "join timeout must greater than join retry interval"). AddValidator(validation.NewEmptyStringValidator("host", config.host)). AddValidator(validation.NewConditionalValidator(len(config.secretKeys) != 0, validation.NewEmptyStringValidator("config.cookie", config.cookie))). diff --git a/internal/tcp/tcp.go b/internal/tcp/tcp.go index c8e3b20..d7c40e2 100644 --- a/internal/tcp/tcp.go +++ b/internal/tcp/tcp.go @@ -24,7 +24,12 @@ package tcp -import "net" +import ( + "fmt" + "net" + + "github.com/hashicorp/go-sockaddr" +) // GetHostPort returns the actual ip address and port from a given address func GetHostPort(address string) (string, int, error) { @@ -36,3 +41,39 @@ func GetHostPort(address string) (string, int, error) { return addr.IP.String(), addr.Port, nil } + +// GetBindIP tries to find an appropriate bindIP to bind and propagate. +func GetBindIP(address string) (string, error) { + bindIP, _, err := GetHostPort(address) + if err != nil { + return "", fmt.Errorf("invalid address: %w", err) + } + + if bindIP == "0.0.0.0" { + // if we're not bound to a specific IP, let's use a suitable private IP address. + ipStr, err := sockaddr.GetPrivateIP() + if err != nil { + return "", fmt.Errorf("failed to get private interface addresses: %w", err) + } + + // if we could not find a private address, we need to expand our search to a public + // ip address + if ipStr == "" { + ipStr, err = sockaddr.GetPublicIP() + if err != nil { + return "", fmt.Errorf("failed to get public interface addresses: %w", err) + } + } + + if ipStr == "" { + return "", fmt.Errorf("no private IP address found, and explicit IP not provided") + } + + parsed := net.ParseIP(ipStr) + if parsed == nil { + return "", fmt.Errorf("failed to parse private IP address: %q", ipStr) + } + bindIP = parsed.String() + } + return bindIP, nil +} diff --git a/node.go b/node.go index 8200d71..b5aa73a 100644 --- a/node.go +++ b/node.go @@ -38,6 +38,7 @@ import ( "time" "connectrpc.com/connect" + "github.com/flowchartsman/retry" "github.com/hashicorp/memberlist" "go.uber.org/atomic" "google.golang.org/protobuf/types/known/timestamppb" @@ -156,7 +157,7 @@ func (node *Node) Start(ctx context.Context) error { AddError(node.config.Validate()). AddError(node.config.provider.Initialize()). AddError(node.config.provider.Register()). - AddError(node.join()). + AddError(node.join(ctx)). AddError(node.serve(ctx)). Error(); err != nil { node.mu.Unlock() @@ -231,7 +232,6 @@ func (node *Node) Put(ctx context.Context, request *connect.Request[internalpb.P } // Get is used to retrieve a key/value pair in a cluster of nodes -// nolint func (node *Node) Get(ctx context.Context, request *connect.Request[internalpb.GetRequest]) (*connect.Response[internalpb.GetResponse], error) { node.mu.Lock() if !node.started.Load() { @@ -239,9 +239,28 @@ func (node *Node) Get(ctx context.Context, request *connect.Request[internalpb.G return nil, connect.NewError(connect.CodeFailedPrecondition, ErrNodeNotStarted) } + ctx, cancelFn := context.WithTimeout(ctx, node.config.readTimeout) + defer cancelFn() + req := request.Msg - entry, err := node.delegate.Get(req.GetKey()) - if err != nil { + var ( + rerr error + entry *internalpb.Entry + ) + + retrier := retry.NewRetrier(2, node.config.readTimeout, node.config.syncInterval) + if err := retrier.RunContext(ctx, func(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + entry, rerr = node.delegate.Get(req.GetKey()) + if rerr != nil { + return rerr + } + } + return nil + }); err != nil { node.mu.Unlock() return nil, connect.NewError(connect.CodeNotFound, err) } @@ -342,13 +361,14 @@ func (node *Node) Peers() ([]*Member, error) { // serve start the underlying http server func (node *Node) serve(ctx context.Context) error { // extract the actual TCP ip discoveryAddress - host, port, err := tcp.GetHostPort(fmt.Sprintf("%s:%d", node.config.host, node.config.port)) + hostPort := net.JoinHostPort(node.config.host, strconv.Itoa(int(node.config.port))) + bindIP, err := tcp.GetBindIP(hostPort) if err != nil { return fmt.Errorf("failed to resolve TCP discoveryAddress: %w", err) } - node.config.WithHost(host) - node.config.WithPort(uint16(port)) + node.config.WithHost(bindIP) + node.config.WithPort(uint16(node.config.port)) // hook the node as the KV service handler // TODO: add metric options to the handler @@ -372,20 +392,29 @@ func (node *Node) serve(ctx context.Context) error { } // join attempts to join an existing cluster if node peers is provided -func (node *Node) join() error { +func (node *Node) join(ctx context.Context) error { mlist, err := memberlist.Create(node.memberConfig) if err != nil { node.config.logger.Error(fmt.Errorf("failed to create memberlist: %w", err)) return err } - // TODO: use a retry mechanism here - peers, err := node.config.provider.DiscoverPeers() - if err != nil { - node.config.logger.Error(fmt.Errorf("failed to discover peers: %w", err)) + ctx2, cancel := context.WithTimeout(ctx, node.config.joinTimeout) + var peers []string + retrier := retry.NewRetrier(node.config.maxJoinAttempts, node.config.joinRetryInterval, node.config.joinRetryInterval) + if err := retrier.RunContext(ctx2, func(ctx context.Context) error { // nolint + peers, err = node.config.provider.DiscoverPeers() + if err != nil { + return err + } + return nil + }); err != nil { + cancel() return err } + cancel() + // set the mlist node.memberlist = mlist if len(peers) > 0 { diff --git a/node_test.go b/node_test.go index a71e731..6e72f80 100644 --- a/node_test.go +++ b/node_test.go @@ -251,6 +251,8 @@ func startNode(t *testing.T, serverAddr string) (*Node, discovery.Provider) { host: host, syncInterval: 500 * time.Millisecond, joinRetryInterval: 500 * time.Millisecond, + joinTimeout: time.Second, + readTimeout: 500 * time.Millisecond, maxJoinAttempts: 5, cookie: cookie, secretKeys: []string{b64},