Skip to content

Commit

Permalink
feat: support for enforcing the use of the leader db (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenquan authored Jul 8, 2023
1 parent 0065098 commit 1e2c0dd
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
20 changes: 18 additions & 2 deletions multiple.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,11 @@ func (m *multipleSqlConn) containSelect(query string) bool {
return false
}

func (m *multipleSqlConn) getQueryDB(query string) queryDB {
func (m *multipleSqlConn) getQueryDB(ctx context.Context, query string) queryDB {
if forceLeaderFromContext(ctx) {
return queryDB{conn: m.leader}
}

if !m.enableFollower {
return queryDB{conn: m.leader}
}
Expand Down Expand Up @@ -242,7 +246,7 @@ func (m *multipleSqlConn) startSpanWithFollower(ctx context.Context, db int) (co
}

func (m *multipleSqlConn) query(ctx context.Context, query string, do func(ctx context.Context, conn sqlx.SqlConn) error) error {
db := m.getQueryDB(query)
db := m.getQueryDB(ctx, query)
var span oteltrace.Span
if db.follower {
ctx, span = m.startSpanWithFollower(ctx, db.followerDB)
Expand Down Expand Up @@ -297,3 +301,15 @@ func WithAccept(accept func(err error) bool) SqlOption {
conn.accept = accept
}
}

type forceLeaderKey struct{}

func ForceLeaderContext(ctx context.Context) context.Context {
return context.WithValue(ctx, forceLeaderKey{}, struct{}{})
}

func forceLeaderFromContext(ctx context.Context) bool {
value := ctx.Value(forceLeaderKey{})
_, ok := value.(struct{})
return ok
}
8 changes: 8 additions & 0 deletions multiple_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sqlx

import (
"context"
"database/sql/driver"
"testing"
"time"
Expand Down Expand Up @@ -52,3 +53,10 @@ func TestNewMultipleSqlConn(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, int64(1), rowsAffected)
}

func TestForceLeaderContext(t *testing.T) {
ctx := ForceLeaderContext(context.Background())
assert.True(t, forceLeaderFromContext(ctx))

assert.False(t, forceLeaderFromContext(context.Background()))
}

0 comments on commit 1e2c0dd

Please sign in to comment.