diff --git a/multiple.go b/multiple.go index 47f3f11..2759e2d 100644 --- a/multiple.go +++ b/multiple.go @@ -167,7 +167,11 @@ func (m *multipleSqlConn) containSelect(query string) bool { return false } -func (m *multipleSqlConn) getQueryDB(query string) queryDB { +func (m *multipleSqlConn) getQueryDB(ctx context.Context, query string) queryDB { + if forceLeaderFromContext(ctx) { + return queryDB{conn: m.leader} + } + if !m.enableFollower { return queryDB{conn: m.leader} } @@ -242,7 +246,7 @@ func (m *multipleSqlConn) startSpanWithFollower(ctx context.Context, db int) (co } func (m *multipleSqlConn) query(ctx context.Context, query string, do func(ctx context.Context, conn sqlx.SqlConn) error) error { - db := m.getQueryDB(query) + db := m.getQueryDB(ctx, query) var span oteltrace.Span if db.follower { ctx, span = m.startSpanWithFollower(ctx, db.followerDB) @@ -297,3 +301,15 @@ func WithAccept(accept func(err error) bool) SqlOption { conn.accept = accept } } + +type forceLeaderKey struct{} + +func ForceLeaderContext(ctx context.Context) context.Context { + return context.WithValue(ctx, forceLeaderKey{}, struct{}{}) +} + +func forceLeaderFromContext(ctx context.Context) bool { + value := ctx.Value(forceLeaderKey{}) + _, ok := value.(struct{}) + return ok +} diff --git a/multiple_test.go b/multiple_test.go index 40b9e21..4275939 100644 --- a/multiple_test.go +++ b/multiple_test.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "database/sql/driver" "testing" "time" @@ -52,3 +53,10 @@ func TestNewMultipleSqlConn(t *testing.T) { assert.NoError(t, err) assert.Equal(t, int64(1), rowsAffected) } + +func TestForceLeaderContext(t *testing.T) { + ctx := ForceLeaderContext(context.Background()) + assert.True(t, forceLeaderFromContext(ctx)) + + assert.False(t, forceLeaderFromContext(context.Background())) +}