Skip to content

Commit

Permalink
refactor: refactor the implementation (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tochemey authored Dec 18, 2024
1 parent 35903bb commit 897f5ce
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 15 deletions.
15 changes: 13 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type Config struct {
maxJoinAttempts int
// specifies the join retry interval
joinRetryInterval time.Duration
joinTimeout time.Duration
// specifies the discovery provider
provider discovery.Provider
// specifies the node client port
Expand Down Expand Up @@ -89,9 +90,10 @@ func NewConfig() *Config {
maxJoinAttempts: 5,
joinRetryInterval: time.Second,
shutdownTimeout: 3 * time.Second,
joinTimeout: time.Minute,
syncInterval: time.Minute,
logger: log.New(log.ErrorLevel, os.Stderr),
readTimeout: time.Second,
readTimeout: time.Minute,
}
}

Expand Down Expand Up @@ -143,14 +145,20 @@ func (config *Config) WithHost(host string) *Config {
return config
}

// WithSyncInterval sets the delegate sync interval
// WithSyncInterval sets the cluster synchronization interval.
// This is the interval between complete states synchronization between nodes.
// Complete states synchronization are done with a single node over TCP and are
// quite expensive relative to standard gossiped messages.
// Setting this interval lower (more frequent) will increase convergence
// speeds across larger clusters at the expense of increased bandwidth usage.
func (config *Config) WithSyncInterval(interval time.Duration) *Config {
config.syncInterval = interval
return config
}

// WithReadTimeout sets the Node read timeout.
// This timeout specifies the timeout of a data retrieval
// The read timeout should be either greater or equal to syncInterval
func (config *Config) WithReadTimeout(timeout time.Duration) *Config {
config.readTimeout = timeout
return config
Expand Down Expand Up @@ -186,8 +194,11 @@ func (config *Config) Validate() error {
AddAssertion(config.provider != nil, "discovery provider is not set").
AddAssertion(config.joinRetryInterval > 0, "join retry interval is invalid").
AddAssertion(config.shutdownTimeout > 0, "shutdown timeout is invalid").
AddAssertion(config.joinTimeout > 0, "join timeout is invalid").
AddAssertion(config.maxJoinAttempts > 0, "max join attempts is invalid").
AddAssertion(config.syncInterval > 0, "stateSync interval is invalid").
AddAssertion(config.readTimeout > 0, "read timeout is invalid").
AddAssertion(config.joinTimeout > config.joinRetryInterval, "join timeout must greater than join retry interval").
AddValidator(validation.NewEmptyStringValidator("host", config.host)).
AddValidator(validation.NewConditionalValidator(len(config.secretKeys) != 0,
validation.NewEmptyStringValidator("config.cookie", config.cookie))).
Expand Down
43 changes: 42 additions & 1 deletion internal/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@

package tcp

import "net"
import (
"fmt"
"net"

"github.com/hashicorp/go-sockaddr"
)

// GetHostPort returns the actual ip address and port from a given address
func GetHostPort(address string) (string, int, error) {
Expand All @@ -36,3 +41,39 @@ func GetHostPort(address string) (string, int, error) {

return addr.IP.String(), addr.Port, nil
}

// GetBindIP tries to find an appropriate bindIP to bind and propagate.
func GetBindIP(address string) (string, error) {
bindIP, _, err := GetHostPort(address)
if err != nil {
return "", fmt.Errorf("invalid address: %w", err)
}

if bindIP == "0.0.0.0" {
// if we're not bound to a specific IP, let's use a suitable private IP address.
ipStr, err := sockaddr.GetPrivateIP()
if err != nil {
return "", fmt.Errorf("failed to get private interface addresses: %w", err)
}

// if we could not find a private address, we need to expand our search to a public
// ip address
if ipStr == "" {
ipStr, err = sockaddr.GetPublicIP()
if err != nil {
return "", fmt.Errorf("failed to get public interface addresses: %w", err)
}
}

if ipStr == "" {
return "", fmt.Errorf("no private IP address found, and explicit IP not provided")
}

parsed := net.ParseIP(ipStr)
if parsed == nil {
return "", fmt.Errorf("failed to parse private IP address: %q", ipStr)
}
bindIP = parsed.String()
}
return bindIP, nil
}
53 changes: 41 additions & 12 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"time"

"connectrpc.com/connect"
"github.com/flowchartsman/retry"
"github.com/hashicorp/memberlist"
"go.uber.org/atomic"
"google.golang.org/protobuf/types/known/timestamppb"
Expand Down Expand Up @@ -156,7 +157,7 @@ func (node *Node) Start(ctx context.Context) error {
AddError(node.config.Validate()).
AddError(node.config.provider.Initialize()).
AddError(node.config.provider.Register()).
AddError(node.join()).
AddError(node.join(ctx)).
AddError(node.serve(ctx)).
Error(); err != nil {
node.mu.Unlock()
Expand Down Expand Up @@ -231,17 +232,35 @@ func (node *Node) Put(ctx context.Context, request *connect.Request[internalpb.P
}

// Get is used to retrieve a key/value pair in a cluster of nodes
// nolint
func (node *Node) Get(ctx context.Context, request *connect.Request[internalpb.GetRequest]) (*connect.Response[internalpb.GetResponse], error) {
node.mu.Lock()
if !node.started.Load() {
node.mu.Unlock()
return nil, connect.NewError(connect.CodeFailedPrecondition, ErrNodeNotStarted)
}

ctx, cancelFn := context.WithTimeout(ctx, node.config.readTimeout)
defer cancelFn()

req := request.Msg
entry, err := node.delegate.Get(req.GetKey())
if err != nil {
var (
rerr error
entry *internalpb.Entry
)

retrier := retry.NewRetrier(2, node.config.readTimeout, node.config.syncInterval)
if err := retrier.RunContext(ctx, func(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
entry, rerr = node.delegate.Get(req.GetKey())
if rerr != nil {
return rerr
}
}
return nil
}); err != nil {
node.mu.Unlock()
return nil, connect.NewError(connect.CodeNotFound, err)
}
Expand Down Expand Up @@ -342,13 +361,14 @@ func (node *Node) Peers() ([]*Member, error) {
// serve start the underlying http server
func (node *Node) serve(ctx context.Context) error {
// extract the actual TCP ip discoveryAddress
host, port, err := tcp.GetHostPort(fmt.Sprintf("%s:%d", node.config.host, node.config.port))
hostPort := net.JoinHostPort(node.config.host, strconv.Itoa(int(node.config.port)))
bindIP, err := tcp.GetBindIP(hostPort)
if err != nil {
return fmt.Errorf("failed to resolve TCP discoveryAddress: %w", err)
}

node.config.WithHost(host)
node.config.WithPort(uint16(port))
node.config.WithHost(bindIP)
node.config.WithPort(uint16(node.config.port))

// hook the node as the KV service handler
// TODO: add metric options to the handler
Expand All @@ -372,20 +392,29 @@ func (node *Node) serve(ctx context.Context) error {
}

// join attempts to join an existing cluster if node peers is provided
func (node *Node) join() error {
func (node *Node) join(ctx context.Context) error {
mlist, err := memberlist.Create(node.memberConfig)
if err != nil {
node.config.logger.Error(fmt.Errorf("failed to create memberlist: %w", err))
return err
}

// TODO: use a retry mechanism here
peers, err := node.config.provider.DiscoverPeers()
if err != nil {
node.config.logger.Error(fmt.Errorf("failed to discover peers: %w", err))
ctx2, cancel := context.WithTimeout(ctx, node.config.joinTimeout)
var peers []string
retrier := retry.NewRetrier(node.config.maxJoinAttempts, node.config.joinRetryInterval, node.config.joinRetryInterval)
if err := retrier.RunContext(ctx2, func(ctx context.Context) error { // nolint
peers, err = node.config.provider.DiscoverPeers()
if err != nil {
return err
}
return nil
}); err != nil {
cancel()
return err
}

cancel()

// set the mlist
node.memberlist = mlist
if len(peers) > 0 {
Expand Down
2 changes: 2 additions & 0 deletions node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ func startNode(t *testing.T, serverAddr string) (*Node, discovery.Provider) {
host: host,
syncInterval: 500 * time.Millisecond,
joinRetryInterval: 500 * time.Millisecond,
joinTimeout: time.Second,
readTimeout: 500 * time.Millisecond,
maxJoinAttempts: 5,
cookie: cookie,
secretKeys: []string{b64},
Expand Down

0 comments on commit 897f5ce

Please sign in to comment.