diff --git a/common/clock/event_time_source.go b/common/clock/event_time_source.go index da56c25d432..7254d50820f 100644 --- a/common/clock/event_time_source.go +++ b/common/clock/event_time_source.go @@ -27,6 +27,8 @@ package clock import ( "sync" "time" + + "go.temporal.io/server/common/util" ) type ( @@ -135,6 +137,23 @@ func (ts *EventTimeSource) Advance(d time.Duration) { ts.fireTimers() } +// AdvanceNext advances to the next timer. +func (ts *EventTimeSource) AdvanceNext() { + ts.mu.Lock() + defer ts.mu.Unlock() + + if len(ts.timers) == 0 { + return + } + // just do linear search, this is efficient enough for now + tmin := ts.timers[0].deadline + for _, t := range ts.timers[1:] { + tmin = util.MinTime(tmin, t.deadline) + } + ts.now = tmin + ts.fireTimers() +} + // NumTimers returns the number of outstanding timers. func (ts *EventTimeSource) NumTimers() int { ts.mu.Lock() @@ -143,6 +162,12 @@ func (ts *EventTimeSource) NumTimers() int { return len(ts.timers) } +// Sleep is a convenience function for waiting on a new timer. +func (ts *EventTimeSource) Sleep(d time.Duration) { + t, _ := ts.NewTimer(d) + <-t +} + // fireTimers fires all timers that are ready. func (ts *EventTimeSource) fireTimers() { n := 0 diff --git a/common/persistence/cassandra/matching_task_store.go b/common/persistence/cassandra/matching_task_store.go index 43286fd8b52..5d1039efcff 100644 --- a/common/persistence/cassandra/matching_task_store.go +++ b/common/persistence/cassandra/matching_task_store.go @@ -522,53 +522,61 @@ func (d *MatchingTaskStore) UpdateTaskQueueUserData( ) error { batch := d.Session.NewBatch(gocql.UnloggedBatch).WithContext(ctx) - if request.Version == 0 { - batch.Query(templateInsertTaskQueueUserDataQuery, - request.NamespaceID, - request.TaskQueue, - request.UserData.Data, - request.UserData.EncodingType.String(), - ) - } else { - batch.Query(templateUpdateTaskQueueUserDataQuery, - request.UserData.Data, - request.UserData.EncodingType.String(), - request.Version+1, - request.NamespaceID, - request.TaskQueue, - request.Version, - ) - } - for _, buildId := range request.BuildIdsAdded { - batch.Query(templateInsertBuildIdTaskQueueMappingQuery, request.NamespaceID, buildId, request.TaskQueue) - } - for _, buildId := range request.BuildIdsRemoved { - batch.Query(templateDeleteBuildIdTaskQueueMappingQuery, request.NamespaceID, buildId, request.TaskQueue) + for taskQueue, update := range request.Updates { + if update.Version == 0 { + batch.Query(templateInsertTaskQueueUserDataQuery, + request.NamespaceID, + taskQueue, + update.UserData.Data, + update.UserData.EncodingType.String(), + ) + } else { + batch.Query(templateUpdateTaskQueueUserDataQuery, + update.UserData.Data, + update.UserData.EncodingType.String(), + update.Version+1, + request.NamespaceID, + taskQueue, + update.Version, + ) + } + for _, buildId := range update.BuildIdsAdded { + batch.Query(templateInsertBuildIdTaskQueueMappingQuery, request.NamespaceID, buildId, taskQueue) + } + for _, buildId := range update.BuildIdsRemoved { + batch.Query(templateDeleteBuildIdTaskQueueMappingQuery, request.NamespaceID, buildId, taskQueue) + } } - previous := make(map[string]interface{}) + previous := make(map[string]any) applied, iter, err := d.Session.MapExecuteBatchCAS(batch, previous) - - if err != nil { - return gocql.ConvertError("UpdateTaskQueueUserData", err) - } - - // We only care about the conflict in the first query - err = iter.Close() if err != nil { return gocql.ConvertError("UpdateTaskQueueUserData", err) } + defer iter.Close() if 
!applied { - var columns []string - for k, v := range previous { - columns = append(columns, fmt.Sprintf("%s=%v", k, v)) - } - - return &p.ConditionFailedError{ - Msg: fmt.Sprintf("Failed to update task queue. name: %v, version: %v, columns: (%v)", - request.TaskQueue, request.Version, strings.Join(columns, ",")), + // No error, but not applied. That means we had a conflict. + // Iterate through results to identify first conflicting row. + for { + name, nameErr := getTypedFieldFromRow[string]("task_queue_name", previous) + previousVersion, verErr := getTypedFieldFromRow[int64]("version", previous) + update, hasUpdate := request.Updates[name] + if nameErr == nil && verErr == nil && hasUpdate && update.Version != previousVersion { + if update.Conflicting != nil { + *update.Conflicting = true + } + return &p.ConditionFailedError{ + Msg: fmt.Sprintf("Failed to update task queues: task queue %q version %d != %d", + name, update.Version, previousVersion), + } + } + clear(previous) + if !iter.MapScan(previous) { + break + } } + return &p.ConditionFailedError{Msg: "Failed to update task queues: unknown conflict"} } return nil diff --git a/common/persistence/data_interfaces.go b/common/persistence/data_interfaces.go index 81715c93fd1..0023eaf68d6 100644 --- a/common/persistence/data_interfaces.go +++ b/common/persistence/data_interfaces.go @@ -523,13 +523,23 @@ type ( UserData *persistencespb.VersionedTaskQueueUserData } - // UpdateTaskQueueUserDataRequest is the input type for the UpdateTaskQueueUserData API + // UpdateTaskQueueUserDataRequest is the input type for the UpdateTaskQueueUserData API. + // This updates user data for multiple task queues in one namespace. UpdateTaskQueueUserDataRequest struct { - NamespaceID string - TaskQueue string + NamespaceID string + Updates map[string]*SingleTaskQueueUserDataUpdate // key is task queue name + } + + SingleTaskQueueUserDataUpdate struct { UserData *persistencespb.VersionedTaskQueueUserData BuildIdsAdded []string BuildIdsRemoved []string + // If Conflicting is non-nil, and this single update fails due to a version conflict, + // then it will be set to true. Conflicting updates should not be retried. + // Note that even if Conflicting is not set to true, the update may still be + // conflicting, because persistence implementations may only be able to identify the + // first conflict in a set. 
+ Conflicting *bool } ListTaskQueueUserDataEntriesRequest struct { diff --git a/common/persistence/persistence_interface.go b/common/persistence/persistence_interface.go index faa383b4b78..355f77ccb99 100644 --- a/common/persistence/persistence_interface.go +++ b/common/persistence/persistence_interface.go @@ -289,12 +289,16 @@ type ( InternalUpdateTaskQueueUserDataRequest struct { NamespaceID string - TaskQueue string - Version int64 - UserData *commonpb.DataBlob + Updates map[string]*InternalSingleTaskQueueUserDataUpdate // key is task queue name + } + + InternalSingleTaskQueueUserDataUpdate struct { + Version int64 + UserData *commonpb.DataBlob // Used to build an index of build_id to task_queues BuildIdsAdded []string BuildIdsRemoved []string + Conflicting *bool } InternalTaskQueueUserDataEntry struct { diff --git a/common/persistence/sql/task.go b/common/persistence/sql/task.go index ad76d05ee3d..5e0f3b3ee18 100644 --- a/common/persistence/sql/task.go +++ b/common/persistence/sql/task.go @@ -481,38 +481,43 @@ func (m *sqlTaskManager) UpdateTaskQueueUserData(ctx context.Context, request *p return serviceerror.NewInternal(fmt.Sprintf("failed to parse namespace ID as UUID: %v", err)) } err = m.txExecute(ctx, "UpdateTaskQueueUserData", func(tx sqlplugin.Tx) error { - err := tx.UpdateTaskQueueUserData(ctx, &sqlplugin.UpdateTaskQueueDataRequest{ - NamespaceID: namespaceID, - TaskQueueName: request.TaskQueue, - Data: request.UserData.Data, - DataEncoding: request.UserData.EncodingType.String(), - Version: request.Version, - }) - if m.Db.IsDupEntryError(err) { - return &persistence.ConditionFailedError{Msg: err.Error()} - } - if err != nil { - return err - } - if len(request.BuildIdsAdded) > 0 { - err = tx.AddToBuildIdToTaskQueueMapping(ctx, sqlplugin.AddToBuildIdToTaskQueueMapping{ + for taskQueue, update := range request.Updates { + err := tx.UpdateTaskQueueUserData(ctx, &sqlplugin.UpdateTaskQueueDataRequest{ NamespaceID: namespaceID, - TaskQueueName: request.TaskQueue, - BuildIds: request.BuildIdsAdded, + TaskQueueName: taskQueue, + Data: update.UserData.Data, + DataEncoding: update.UserData.EncodingType.String(), + Version: update.Version, }) - if err != nil { - return err + if m.Db.IsDupEntryError(err) { + err = &persistence.ConditionFailedError{Msg: err.Error()} + } + if persistence.IsConflictErr(err) && update.Conflicting != nil { + *update.Conflicting = true } - } - if len(request.BuildIdsRemoved) > 0 { - err = tx.RemoveFromBuildIdToTaskQueueMapping(ctx, sqlplugin.RemoveFromBuildIdToTaskQueueMapping{ - NamespaceID: namespaceID, - TaskQueueName: request.TaskQueue, - BuildIds: request.BuildIdsRemoved, - }) if err != nil { return err } + if len(update.BuildIdsAdded) > 0 { + err = tx.AddToBuildIdToTaskQueueMapping(ctx, sqlplugin.AddToBuildIdToTaskQueueMapping{ + NamespaceID: namespaceID, + TaskQueueName: taskQueue, + BuildIds: update.BuildIdsAdded, + }) + if err != nil { + return err + } + } + if len(update.BuildIdsRemoved) > 0 { + err = tx.RemoveFromBuildIdToTaskQueueMapping(ctx, sqlplugin.RemoveFromBuildIdToTaskQueueMapping{ + NamespaceID: namespaceID, + TaskQueueName: taskQueue, + BuildIds: update.BuildIdsRemoved, + }) + if err != nil { + return err + } + } } return nil }) diff --git a/common/persistence/task_manager.go b/common/persistence/task_manager.go index ef24f0a8163..139f9e2a3e0 100644 --- a/common/persistence/task_manager.go +++ b/common/persistence/task_manager.go @@ -256,17 +256,22 @@ func (m *taskManagerImpl) GetTaskQueueUserData(ctx context.Context, request *Get // 
UpdateTaskQueueUserData implements TaskManager func (m *taskManagerImpl) UpdateTaskQueueUserData(ctx context.Context, request *UpdateTaskQueueUserDataRequest) error { - userData, err := m.serializer.TaskQueueUserDataToBlob(request.UserData.Data, enumspb.ENCODING_TYPE_PROTO3) - if err != nil { - return err - } internalRequest := &InternalUpdateTaskQueueUserDataRequest{ - NamespaceID: request.NamespaceID, - TaskQueue: request.TaskQueue, - Version: request.UserData.Version, - UserData: userData, - BuildIdsAdded: request.BuildIdsAdded, - BuildIdsRemoved: request.BuildIdsRemoved, + NamespaceID: request.NamespaceID, + Updates: make(map[string]*InternalSingleTaskQueueUserDataUpdate, len(request.Updates)), + } + for taskQueue, update := range request.Updates { + userData, err := m.serializer.TaskQueueUserDataToBlob(update.UserData.Data, enumspb.ENCODING_TYPE_PROTO3) + if err != nil { + return err + } + internalRequest.Updates[taskQueue] = &InternalSingleTaskQueueUserDataUpdate{ + Version: update.UserData.Version, + UserData: userData, + BuildIdsAdded: update.BuildIdsAdded, + BuildIdsRemoved: update.BuildIdsRemoved, + Conflicting: update.Conflicting, + } } return m.taskStore.UpdateTaskQueueUserData(ctx, internalRequest) } diff --git a/common/persistence/tests/cassandra_test.go b/common/persistence/tests/cassandra_test.go index c8b359ac464..0c68152ea5a 100644 --- a/common/persistence/tests/cassandra_test.go +++ b/common/persistence/tests/cassandra_test.go @@ -249,6 +249,19 @@ func TestCassandraTaskQueueTaskSuite(t *testing.T) { suite.Run(t, s) } +func TestCassandraTaskQueueUserDataSuite(t *testing.T) { + testData, tearDown := setUpCassandraTest(t) + defer tearDown() + + taskQueueStore, err := testData.Factory.NewTaskStore() + if err != nil { + t.Fatalf("unable to create Cassandra DB: %v", err) + } + + s := NewTaskQueueUserDataSuite(t, taskQueueStore, testData.Logger) + suite.Run(t, s) +} + func TestCassandraHistoryV2Persistence(t *testing.T) { s := new(persistencetests.HistoryV2PersistenceSuite) s.TestBase = persistencetests.NewTestBaseWithCassandra(&persistencetests.TestBaseOptions{}) diff --git a/common/persistence/tests/mysql_test.go b/common/persistence/tests/mysql_test.go index 5d98146adac..74cf2440321 100644 --- a/common/persistence/tests/mysql_test.go +++ b/common/persistence/tests/mysql_test.go @@ -149,6 +149,19 @@ func TestMySQLTaskQueueTaskSuite(t *testing.T) { suite.Run(t, s) } +func TestMySQLTaskQueueUserDataSuite(t *testing.T) { + testData, tearDown := setUpMySQLTest(t) + defer tearDown() + + taskQueueStore, err := testData.Factory.NewTaskStore() + if err != nil { + t.Fatalf("unable to create MySQL DB: %v", err) + } + + s := NewTaskQueueUserDataSuite(t, taskQueueStore, testData.Logger) + suite.Run(t, s) +} + func TestMySQLVisibilityPersistenceSuite(t *testing.T) { s := &VisibilityPersistenceSuite{ TestBase: persistencetests.NewTestBaseWithSQL(persistencetests.GetMySQLTestClusterOption()), diff --git a/common/persistence/tests/postgresql_test.go b/common/persistence/tests/postgresql_test.go index 2e5bd1b97b2..20646449192 100644 --- a/common/persistence/tests/postgresql_test.go +++ b/common/persistence/tests/postgresql_test.go @@ -149,6 +149,19 @@ func (p *PostgreSQLSuite) TestPostgreSQLTaskQueueTaskSuite() { suite.Run(p.T(), s) } +func (p *PostgreSQLSuite) TestPostgreSQLTaskQueueUserDataSuite() { + testData, tearDown := setUpPostgreSQLTest(p.T(), p.pluginName) + defer tearDown() + + taskQueueStore, err := testData.Factory.NewTaskStore() + if err != nil { + p.T().Fatalf("unable to create 
PostgreSQL DB: %v", err) + } + + s := NewTaskQueueUserDataSuite(p.T(), taskQueueStore, testData.Logger) + suite.Run(p.T(), s) +} + func (p *PostgreSQLSuite) TestPostgreSQLVisibilityPersistenceSuite() { s := &VisibilityPersistenceSuite{ TestBase: persistencetests.NewTestBaseWithSQL(persistencetests.GetPostgreSQLTestClusterOption()), diff --git a/common/persistence/tests/sqlite_test.go b/common/persistence/tests/sqlite_test.go index d8b5377c501..7d09c7d4b3b 100644 --- a/common/persistence/tests/sqlite_test.go +++ b/common/persistence/tests/sqlite_test.go @@ -240,6 +240,28 @@ func TestSQLiteTaskQueueTaskSuite(t *testing.T) { suite.Run(t, s) } +func TestSQLiteTaskQueueUserDataSuite(t *testing.T) { + cfg := NewSQLiteMemoryConfig() + logger := log.NewNoopLogger() + factory := sql.NewFactory( + *cfg, + resolver.NewNoopResolver(), + testSQLiteClusterName, + logger, + metrics.NoopMetricsHandler, + ) + taskQueueStore, err := factory.NewTaskStore() + if err != nil { + t.Fatalf("unable to create SQLite DB: %v", err) + } + defer func() { + factory.Close() + }() + + s := NewTaskQueueUserDataSuite(t, taskQueueStore, logger) + suite.Run(t, s) +} + func TestSQLiteFileExecutionMutableStateStoreSuite(t *testing.T) { cfg := NewSQLiteFileConfig() SetupSQLiteDatabase(t, cfg) @@ -391,6 +413,32 @@ func TestSQLiteFileTaskQueueTaskSuite(t *testing.T) { suite.Run(t, s) } +func TestSQLiteFileTaskQueueUserDataSuite(t *testing.T) { + cfg := NewSQLiteFileConfig() + SetupSQLiteDatabase(t, cfg) + defer func() { + assert.NoError(t, os.Remove(cfg.DatabaseName)) + }() + logger := log.NewNoopLogger() + factory := sql.NewFactory( + *cfg, + resolver.NewNoopResolver(), + testSQLiteClusterName, + logger, + metrics.NoopMetricsHandler, + ) + taskQueueStore, err := factory.NewTaskStore() + if err != nil { + t.Fatalf("unable to create SQLite DB: %v", err) + } + defer func() { + factory.Close() + }() + + s := NewTaskQueueUserDataSuite(t, taskQueueStore, logger) + suite.Run(t, s) +} + // TODO: Merge persistence-tests into the tests directory. func TestSQLiteVisibilityPersistenceSuite(t *testing.T) { diff --git a/common/persistence/tests/task_queue_user_data.go b/common/persistence/tests/task_queue_user_data.go new file mode 100644 index 00000000000..7455960af63 --- /dev/null +++ b/common/persistence/tests/task_queue_user_data.go @@ -0,0 +1,206 @@ +// The MIT License +// +// Copyright (c) 2025 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package tests + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + persistencespb "go.temporal.io/server/api/persistence/v1" + "go.temporal.io/server/common/clock" + hlc "go.temporal.io/server/common/clock/hybrid_logical_clock" + "go.temporal.io/server/common/debug" + "go.temporal.io/server/common/log" + p "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" +) + +type ( + TaskQueueUserDataSuite struct { + suite.Suite + *require.Assertions + + namespaceID string + + taskManager p.TaskManager + logger log.Logger + + ctx context.Context + cancel context.CancelFunc + } +) + +func NewTaskQueueUserDataSuite( + t *testing.T, + taskStore p.TaskStore, + logger log.Logger, +) *TaskQueueUserDataSuite { + return &TaskQueueUserDataSuite{ + Assertions: require.New(t), + taskManager: p.NewTaskManager( + taskStore, + serialization.NewSerializer(), + ), + logger: logger, + } +} + +func (s *TaskQueueUserDataSuite) SetupTest() { + s.Assertions = require.New(s.T()) + s.ctx, s.cancel = context.WithTimeout(context.Background(), 30*time.Second*debug.TimeoutMultiplier) + + s.namespaceID = uuid.New().String() +} + +func (s *TaskQueueUserDataSuite) TearDownTest() { + s.cancel() +} + +func (s *TaskQueueUserDataSuite) TestSetInitialAndIncrement() { + tq1 := "tq1" + version := int64(0) // initial version must be 0 + + // check that get returns not found + _, err := s.taskManager.GetTaskQueueUserData(s.ctx, &p.GetTaskQueueUserDataRequest{ + NamespaceID: s.namespaceID, + TaskQueue: tq1, + }) + s.Error(err) + + // set initial user data + d1 := s.makeData(hlc.Zero(12345), version) + err = s.taskManager.UpdateTaskQueueUserData(s.ctx, &p.UpdateTaskQueueUserDataRequest{ + NamespaceID: s.namespaceID, + Updates: map[string]*p.SingleTaskQueueUserDataUpdate{ + tq1: &p.SingleTaskQueueUserDataUpdate{ + UserData: d1, + }, + }, + }) + s.NoError(err) + version++ + + // check that we can get it back + res, err := s.taskManager.GetTaskQueueUserData(s.ctx, &p.GetTaskQueueUserDataRequest{ + NamespaceID: s.namespaceID, + TaskQueue: tq1, + }) + s.NoError(err) + s.Equal(version, res.UserData.Version) + s.True(hlc.Equal(d1.Data.Clock, res.UserData.Data.Clock)) + + // increment it + d2 := s.makeData(d1.Data.Clock, version) + err = s.taskManager.UpdateTaskQueueUserData(s.ctx, &p.UpdateTaskQueueUserDataRequest{ + NamespaceID: s.namespaceID, + Updates: map[string]*p.SingleTaskQueueUserDataUpdate{ + tq1: &p.SingleTaskQueueUserDataUpdate{ + UserData: d2, + }, + }, + }) + s.NoError(err) + version++ + + // check that we can get it back + res, err = s.taskManager.GetTaskQueueUserData(s.ctx, &p.GetTaskQueueUserDataRequest{ + NamespaceID: s.namespaceID, + TaskQueue: tq1, + }) + s.NoError(err) + s.Equal(version, res.UserData.Version) + s.True(hlc.Equal(d2.Data.Clock, res.UserData.Data.Clock)) +} + +func (s *TaskQueueUserDataSuite) TestUpdateConflict() { + tq1, tq2, tq3 := "tq1", "tq2", "tq3" + + // set up three task queues + data := s.makeData(hlc.Zero(12345), 0) + for range 3 { + err := s.taskManager.UpdateTaskQueueUserData(s.ctx, &p.UpdateTaskQueueUserDataRequest{ + NamespaceID: s.namespaceID, + Updates: 
map[string]*p.SingleTaskQueueUserDataUpdate{ + tq1: &p.SingleTaskQueueUserDataUpdate{UserData: data}, + tq2: &p.SingleTaskQueueUserDataUpdate{UserData: data}, + tq3: &p.SingleTaskQueueUserDataUpdate{UserData: data}, + }, + }) + s.NoError(err) + data.Version++ + } + + // get all and verify + for _, tq := range []string{tq1, tq2, tq3} { + res, err := s.taskManager.GetTaskQueueUserData(s.ctx, &p.GetTaskQueueUserDataRequest{ + NamespaceID: s.namespaceID, + TaskQueue: tq, + }) + s.NoError(err) + s.Equal(int64(3), res.UserData.Version) + s.True(hlc.Equal(data.Data.Clock, res.UserData.Data.Clock)) + } + + // do update where one conflicts + d4 := s.makeData(data.Data.Clock, 4) + var conflict1, conflict2, conflict3 bool + err := s.taskManager.UpdateTaskQueueUserData(s.ctx, &p.UpdateTaskQueueUserDataRequest{ + NamespaceID: s.namespaceID, + Updates: map[string]*p.SingleTaskQueueUserDataUpdate{ + tq1: &p.SingleTaskQueueUserDataUpdate{UserData: data, Conflicting: &conflict1}, + tq2: &p.SingleTaskQueueUserDataUpdate{UserData: d4, Conflicting: &conflict2}, + tq3: &p.SingleTaskQueueUserDataUpdate{UserData: data, Conflicting: &conflict3}, + }, + }) + s.Error(err) + s.True(p.IsConflictErr(err)) + s.False(conflict1) + s.True(conflict2) + s.False(conflict3) + + // verify that none were updated + for _, tq := range []string{tq1, tq2, tq3} { + res, err := s.taskManager.GetTaskQueueUserData(s.ctx, &p.GetTaskQueueUserDataRequest{ + NamespaceID: s.namespaceID, + TaskQueue: tq, + }) + s.NoError(err) + s.Equal(int64(3), res.UserData.Version) + s.True(hlc.Equal(data.Data.Clock, res.UserData.Data.Clock)) + } +} + +func (s *TaskQueueUserDataSuite) makeData(prev *hlc.Clock, ver int64) *persistencespb.VersionedTaskQueueUserData { + return &persistencespb.VersionedTaskQueueUserData{ + Data: &persistencespb.TaskQueueUserData{ + Clock: hlc.Next(prev, clock.NewRealTimeSource()), + }, + Version: ver, + } +} diff --git a/common/stream_batcher/batcher.go b/common/stream_batcher/batcher.go new file mode 100644 index 00000000000..fad71937766 --- /dev/null +++ b/common/stream_batcher/batcher.go @@ -0,0 +1,176 @@ +// The MIT License +// +// Copyright (c) 2024 Temporal Technologies Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package stream_batcher + +import ( + "context" + "sync/atomic" + "time" + + "go.temporal.io/server/common/clock" +) + +// Batcher collects items concurrently passed to Add into batches and calls a processing +// function on them a batch at a time. 
+// The processing function will be called on batches of items in a single-threaded manner, and +// Add will block while fn is running. +type Batcher[T, R any] struct { + fn func([]T) R // batch executor function + opts BatcherOptions // timing/size options + timeSource clock.TimeSource // clock for testing + submitC chan batchPair[T, R] // channel for submitting items + // keeps track of goroutine state: + // if goroutine is not running, running == nil. + // if it is running, running points to a channel that will be closed when the goroutine exits. + running atomic.Pointer[chan struct{}] +} + +type batchPair[T, R any] struct { + resp chan R // response channel + item T // item to add +} + +type BatcherOptions struct { + // MaxItems is the maximum number of items in a batch. + MaxItems int + // MinDelay is how long to wait for no more items to come in after any item before + // finishing the batch. + MinDelay time.Duration + // MaxDelay is the maximum time to wait after the first item in a batch before finishing + // the batch. + MaxDelay time.Duration + // IdleTime is the time after which the internal goroutine will exit, to avoid wasting + // resources on idle streams. + IdleTime time.Duration +} + +// NewBatcher creates a Batcher. `fn` is the processing function, `opts` are the timing options. +// `clock` is usually clock.NewRealTimeSource but can be a fake time source for testing. +func NewBatcher[T, R any](fn func([]T) R, opts BatcherOptions, timeSource clock.TimeSource) *Batcher[T, R] { + return &Batcher[T, R]{ + fn: fn, + opts: opts, + timeSource: timeSource, + submitC: make(chan batchPair[T, R]), + } +} + +// Add adds an item to the stream and returns when it has been processed, or if the context is +// canceled or times out. It returns two values: the value that the batch processor returned, +// and a context error. Even if Add returns a context error, the item may still be processed in +// the future! +func (b *Batcher[T, R]) Add(ctx context.Context, t T) (R, error) { + resp := make(chan R) + pair := batchPair[T, R]{resp: resp, item: t} + + for { + runningC := b.running.Load() + for runningC == nil { + // goroutine is not running, try to start it + newRunningC := make(chan struct{}) + if b.running.CompareAndSwap(nil, &newRunningC) { + // we were the first one to notice the nil, start it now + go b.loop(&newRunningC) + } + // if CompareAndSwap failed, someone else was calling Add at the same time and + // started the goroutine already. reload to get the new running channel. + runningC = b.running.Load() + } + + select { + case <-(*runningC): + // we loaded a non-nil running channel, but it closed while we're waiting to + // submit. the goroutine must have just exited. try again. + continue + case b.submitC <- pair: + select { + case r := <-resp: + return r, nil + case <-ctx.Done(): + var zeroR R + return zeroR, ctx.Err() + } + case <-ctx.Done(): + var zeroR R + return zeroR, ctx.Err() + } + } +} + +func (b *Batcher[T, R]) loop(runningC *chan struct{}) { + defer func() { + // store nil so that Add knows it should start a goroutine + b.running.Store(nil) + // if Add loaded s.running after we decided to stop but before we Stored nil, so it + // thought we were running when we're not, then we need to wake it up so that it can + // start us again. + close(*runningC) + }() + + var items []T + var resps []chan R + for { + clear(items) + clear(resps) + items, resps = items[:0], resps[:0] + + // wait for first item. 
if no item after a while, exit the goroutine + idleC, idleT := b.timeSource.NewTimer(b.opts.IdleTime) + select { + case pair := <-b.submitC: + items = append(items, pair.item) + resps = append(resps, pair.resp) + case <-idleC: + return + } + idleT.Stop() + + // try to add more items. stop after a gap of MinDelay, total time of MaxDelay, + // or MaxItems items. + maxWaitC, maxWaitT := b.timeSource.NewTimer(b.opts.MaxDelay) + loop: + for len(items) < b.opts.MaxItems { + gapC, gapT := b.timeSource.NewTimer(b.opts.MinDelay) + select { + case pair := <-b.submitC: + items = append(items, pair.item) + resps = append(resps, pair.resp) + case <-gapC: + break loop + case <-maxWaitC: + gapT.Stop() + break loop + } + gapT.Stop() + } + maxWaitT.Stop() + + // process batch + r := b.fn(items) + + // send responses + for _, resp := range resps { + resp <- r + } + } +} diff --git a/common/stream_batcher/batcher_test.go b/common/stream_batcher/batcher_test.go new file mode 100644 index 00000000000..763a5c03800 --- /dev/null +++ b/common/stream_batcher/batcher_test.go @@ -0,0 +1,297 @@ +// The MIT License +// +// Copyright (c) 2024 Temporal Technologies Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package stream_batcher + +import ( + "context" + "math/rand" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.temporal.io/server/common/clock" +) + +// this test uses time.Sleep to allow goroutines to get into a blocked state +// +//nolint:forbidigo +func TestStreamBatcher_MinDelay(t *testing.T) { + clk := clock.NewEventTimeSource() + + opts := BatcherOptions{ + MaxItems: 10, + MinDelay: 100 * time.Millisecond, + MaxDelay: 400 * time.Millisecond, + IdleTime: 1000 * time.Millisecond, + } + process := func(items []int) (total int) { + for _, i := range items { + total += i + } + clk.Sleep(50 * time.Millisecond) + return + } + sb := NewBatcher(process, opts, clk) + + var wg sync.WaitGroup + wg.Add(3) + + // 350 = 200 (last Add call) + 100 (MinDelay) + 50 (process time) + targetMS := int64(350) + + go func() { + defer wg.Done() + clk.Sleep(100 * time.Millisecond) + ctx := context.Background() + total, err := sb.Add(ctx, 100) + assert.NoError(t, err) + assert.Equal(t, 123, total) + assert.Equal(t, targetMS, clk.Now().UnixMilli()) + }() + go func() { + defer wg.Done() + clk.Sleep(150 * time.Millisecond) + ctx := context.Background() + total, err := sb.Add(ctx, 20) + assert.NoError(t, err) + assert.Equal(t, 123, total) + assert.Equal(t, targetMS, clk.Now().UnixMilli()) + }() + go func() { + defer wg.Done() + clk.Sleep(200 * time.Millisecond) + ctx := context.Background() + total, err := sb.Add(ctx, 3) + assert.NoError(t, err) + assert.Equal(t, 123, total) + assert.Equal(t, targetMS, clk.Now().UnixMilli()) + }() + + time.Sleep(time.Millisecond) // wait for goroutines to get into Add + clk.AdvanceNext() // first Add + time.Sleep(time.Millisecond) + clk.AdvanceNext() // second Add + time.Sleep(time.Millisecond) + clk.AdvanceNext() // third add + time.Sleep(time.Millisecond) + clk.AdvanceNext() // min delay + time.Sleep(time.Millisecond) + clk.AdvanceNext() // process time + + wg.Wait() +} + +// this test uses time.Sleep to allow goroutines to get into a blocked state +// +//nolint:forbidigo +func TestStreamBatcher_MaxDelay(t *testing.T) { + clk := clock.NewEventTimeSource() + + opts := BatcherOptions{ + MaxItems: 10, + MinDelay: 100 * time.Millisecond, + MaxDelay: 120 * time.Millisecond, + IdleTime: 1000 * time.Millisecond, + } + process := func(items []int) (total int) { + for _, i := range items { + total += i + } + clk.Sleep(50 * time.Millisecond) + return + } + sb := NewBatcher(process, opts, clk) + + var wg sync.WaitGroup + wg.Add(4) + + // 270 = 220 (first Add call + max delay) + 50 (process time) + targetMS := int64(270) + + go func() { + defer wg.Done() + clk.Sleep(100 * time.Millisecond) + ctx := context.Background() + total, err := sb.Add(ctx, 100) + assert.NoError(t, err) + assert.Equal(t, 123, total) + assert.Equal(t, targetMS, clk.Now().UnixMilli()) + }() + go func() { + defer wg.Done() + clk.Sleep(150 * time.Millisecond) + ctx := context.Background() + total, err := sb.Add(ctx, 20) + assert.NoError(t, err) + assert.Equal(t, 123, total) + assert.Equal(t, targetMS, clk.Now().UnixMilli()) + }() + go func() { + defer wg.Done() + clk.Sleep(200 * time.Millisecond) + ctx := context.Background() + total, err := sb.Add(ctx, 3) + assert.NoError(t, err) + assert.Equal(t, 123, total) + assert.Equal(t, targetMS, clk.Now().UnixMilli()) + }() + go func() { + defer wg.Done() + clk.Sleep(250 * time.Millisecond) + ctx := context.Background() + // misses the first batch, will be in separate batch + total, err := sb.Add(ctx, 777) + 
assert.NoError(t, err) + assert.Equal(t, 777, total) + // end of first process time + 100 (min delay) + 50 (second process time) + assert.Equal(t, targetMS+100+50, clk.Now().UnixMilli()) + }() + + time.Sleep(time.Millisecond) // wait for goroutines to get into Add + clk.AdvanceNext() // first Add + time.Sleep(time.Millisecond) + clk.AdvanceNext() // second Add + time.Sleep(time.Millisecond) + clk.AdvanceNext() // third add + time.Sleep(time.Millisecond) + clk.AdvanceNext() // max delay + time.Sleep(time.Millisecond) + clk.AdvanceNext() // process time + time.Sleep(time.Millisecond) + clk.AdvanceNext() // fourth add + time.Sleep(time.Millisecond) + clk.AdvanceNext() // min delay + time.Sleep(time.Millisecond) + clk.AdvanceNext() // process time + + wg.Wait() +} + +// this test uses time.Sleep to allow goroutines to get into a blocked state +// +//nolint:forbidigo +func TestStreamBatcher_MaxItems(t *testing.T) { + clk := clock.NewEventTimeSource() + + opts := BatcherOptions{ + MaxItems: 2, + MinDelay: 100 * time.Millisecond, + MaxDelay: 400 * time.Millisecond, + IdleTime: 1000 * time.Millisecond, + } + process := func(items []int) (total int) { + for _, i := range items { + total += i + } + clk.Sleep(50 * time.Millisecond) + return + } + sb := NewBatcher(process, opts, clk) + + var wg sync.WaitGroup + wg.Add(2) + + // 200 = 150 (second Add call) + 50 (process time) + targetMS := int64(200) + + go func() { + defer wg.Done() + clk.Sleep(100 * time.Millisecond) + ctx := context.Background() + total, err := sb.Add(ctx, 100) + assert.NoError(t, err) + assert.Equal(t, 123, total) + assert.Equal(t, targetMS, clk.Now().UnixMilli()) + }() + go func() { + defer wg.Done() + clk.Sleep(150 * time.Millisecond) + ctx := context.Background() + total, err := sb.Add(ctx, 23) + assert.NoError(t, err) + assert.Equal(t, 123, total) + assert.Equal(t, targetMS, clk.Now().UnixMilli()) + }() + + time.Sleep(time.Millisecond) // wait for goroutines to get into Add + clk.AdvanceNext() // first Add + time.Sleep(time.Millisecond) + clk.AdvanceNext() // second Add + time.Sleep(time.Millisecond) + clk.AdvanceNext() // process time + + wg.Wait() +} + +func TestStreamBatcher_Random(t *testing.T) { + // throw a lot of concurrent calls at the batcher and make sure there are no errors at + // least. with log statements in stream_batcher.go, you can see this does (or did at some + // point) exercise all of the tricky spots. + + clk := clock.NewEventTimeSource() + + const workers = 20 + const events = 1000 + + opts := BatcherOptions{ + MaxItems: 10, + MinDelay: 100 * time.Millisecond, + MaxDelay: 400 * time.Millisecond, + IdleTime: 100 * time.Millisecond, + } + process := func(items []int) (total int) { + for _, i := range items { + total += i + } + clk.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) + return + } + sb := NewBatcher(process, opts, clk) + + var running atomic.Int64 + for range workers { + running.Add(1) + go func() { + ctx := context.Background() + for range events { + _, err := sb.Add(ctx, 1) + assert.NoError(t, err) + clk.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond) + } + running.Add(-1) + }() + } + + for running.Load() > 0 { + clk.AdvanceNext() + // What we'd really like is a way to say "wait until all of these goroutines are + // blocked on a timer". But that's not quite possible with the Go runtime. Running + // Gosched repeatedly is good enough for this test. 
+ for range workers + 1 { + runtime.Gosched() + } + } +} diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index f8e46091b8c..8cafa1142b2 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -56,6 +56,7 @@ import ( hlc "go.temporal.io/server/common/clock/hybrid_logical_clock" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/collection" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/membership" @@ -68,6 +69,7 @@ import ( "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/resource" serviceerrors "go.temporal.io/server/common/serviceerror" + "go.temporal.io/server/common/stream_batcher" "go.temporal.io/server/common/tasktoken" "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/tqid" @@ -106,9 +108,9 @@ type ( forwardedFrom string } - namespaceUpdateLocks struct { - updateLock sync.Mutex - replicationLock sync.Mutex + userDataUpdate struct { + taskQueue string + update persistence.SingleTaskQueueUserDataUpdate } gaugeMetrics struct { @@ -162,10 +164,10 @@ type ( outstandingPollers collection.SyncMap[string, context.CancelFunc] // Only set if global namespaces are enabled on the cluster. namespaceReplicationQueue persistence.NamespaceReplicationQueue - // Disables concurrent task queue user data updates and replication requests (due to a cassandra limitation) - namespaceUpdateLockMap map[string]*namespaceUpdateLocks - // Serializes access to the per namespace lock map - namespaceUpdateLockMapLock sync.Mutex + // Lock to serialize replication queue updates. + replicationLock sync.Mutex + // Serialize and batch user data updates by namespace. + userDataUpdateBatchers collection.SyncMap[namespace.ID, *stream_batcher.Batcher[*userDataUpdate, error]] // Stores results of reachability queries to visibility reachabilityCache reachabilityCache } @@ -184,6 +186,14 @@ var ( // The routing key for the single partition used to route Nexus endpoints CRUD RPCs to. nexusEndpointsTablePartitionRoutingKey = tqid.MustNormalPartitionFromRpcName("not-applicable", "not-applicable", enumspb.TASK_QUEUE_TYPE_UNSPECIFIED).RoutingKey() + + // Options for batching user data updates. 
+ userDataBatcherOptions = stream_batcher.BatcherOptions{ + MaxItems: 100, + MinDelay: 100 * time.Millisecond, + MaxDelay: 500 * time.Millisecond, + IdleTime: time.Minute, + } ) var _ Engine = (*matchingEngineImpl)(nil) // Asserts that interface is indeed implemented @@ -241,7 +251,7 @@ func NewEngine( nexusResults: collection.NewSyncMap[string, chan *nexusResult](), outstandingPollers: collection.NewSyncMap[string, context.CancelFunc](), namespaceReplicationQueue: namespaceReplicationQueue, - namespaceUpdateLockMap: make(map[string]*namespaceUpdateLocks), + userDataUpdateBatchers: collection.NewSyncMap[namespace.ID, *stream_batcher.Batcher[*userDataUpdate, error]](), } e.reachabilityCache = newReachabilityCache( metrics.NoopMetricsHandler, @@ -1808,18 +1818,35 @@ func (e *matchingEngineImpl) ForceUnloadTaskQueuePartition( } func (e *matchingEngineImpl) UpdateTaskQueueUserData(ctx context.Context, request *matchingservice.UpdateTaskQueueUserDataRequest) (*matchingservice.UpdateTaskQueueUserDataResponse, error) { - locks := e.getNamespaceUpdateLocks(request.GetNamespaceId()) - locks.updateLock.Lock() - defer locks.updateLock.Unlock() - - err := e.taskManager.UpdateTaskQueueUserData(ctx, &persistence.UpdateTaskQueueUserDataRequest{ - NamespaceID: request.GetNamespaceId(), - TaskQueue: request.GetTaskQueue(), - UserData: request.GetUserData(), - BuildIdsAdded: request.BuildIdsAdded, - BuildIdsRemoved: request.BuildIdsRemoved, + namespaceId := namespace.ID(request.NamespaceId) + var conflicting bool + persistenceErr, ctxErr := e.getUserDataBatcher(namespaceId).Add(ctx, &userDataUpdate{ + taskQueue: request.GetTaskQueue(), + update: persistence.SingleTaskQueueUserDataUpdate{ + UserData: request.UserData, + BuildIdsAdded: request.BuildIdsAdded, + BuildIdsRemoved: request.BuildIdsRemoved, + Conflicting: &conflicting, + }, }) - return &matchingservice.UpdateTaskQueueUserDataResponse{}, err + if ctxErr != nil { + // Return context errors as-is. + return nil, ctxErr + } + if persistenceErr != nil { + if persistence.IsConflictErr(persistenceErr) { + if conflicting { + // This specific update was the conflicting one. Use InvalidArgument so the + // caller does not retry. + return nil, serviceerror.NewInvalidArgument(persistenceErr.Error()) + } + // This update may or may not be conflicting. Use Unavailable to allow retries. + return nil, serviceerror.NewUnavailable(persistenceErr.Error()) + } + // Other errors from persistence get returned as-is. 
+ return nil, persistenceErr + } + return &matchingservice.UpdateTaskQueueUserDataResponse{}, nil } func (e *matchingEngineImpl) ReplicateTaskQueueUserData(ctx context.Context, request *matchingservice.ReplicateTaskQueueUserDataRequest) (*matchingservice.ReplicateTaskQueueUserDataResponse, error) { @@ -1827,9 +1854,8 @@ func (e *matchingEngineImpl) ReplicateTaskQueueUserData(ctx context.Context, req return &matchingservice.ReplicateTaskQueueUserDataResponse{}, nil } - locks := e.getNamespaceUpdateLocks(request.GetNamespaceId()) - locks.replicationLock.Lock() - defer locks.replicationLock.Unlock() + e.replicationLock.Lock() + defer e.replicationLock.Unlock() err := e.namespaceReplicationQueue.Publish(ctx, &replicationspb.ReplicationTask{ TaskType: enumsspb.REPLICATION_TASK_TYPE_TASK_QUEUE_USER_DATA, @@ -2139,15 +2165,37 @@ func (e *matchingEngineImpl) notifyNexusEndpointsOwnershipChange() { e.nexusEndpointClient.notifyOwnershipChanged(isOwner) } -func (e *matchingEngineImpl) getNamespaceUpdateLocks(namespaceId string) *namespaceUpdateLocks { - e.namespaceUpdateLockMapLock.Lock() - defer e.namespaceUpdateLockMapLock.Unlock() - locks, found := e.namespaceUpdateLockMap[namespaceId] - if !found { - locks = &namespaceUpdateLocks{} - e.namespaceUpdateLockMap[namespaceId] = locks +func (e *matchingEngineImpl) getUserDataBatcher(namespaceId namespace.ID) *stream_batcher.Batcher[*userDataUpdate, error] { + // Note that values are never removed from this map. The batcher's goroutine will exit + // after the idle time, though, which gets most of the desired resource savings. + if batcher, ok := e.userDataUpdateBatchers.Get(namespaceId); ok { + return batcher } - return locks + fn := func(batch []*userDataUpdate) error { + return e.applyUserDataUpdateBatch(namespaceId, batch) + } + newBatcher := stream_batcher.NewBatcher[*userDataUpdate, error](fn, userDataBatcherOptions, e.timeSource) + batcher, _ := e.userDataUpdateBatchers.GetOrSet(namespaceId, newBatcher) + return batcher +} + +func (e *matchingEngineImpl) applyUserDataUpdateBatch(namespaceId namespace.ID, batch []*userDataUpdate) error { + ctx, cancel := context.WithTimeout(context.Background(), ioTimeout) + // TODO: should use namespace name here + ctx = headers.SetCallerInfo(ctx, headers.NewBackgroundCallerInfo(namespaceId.String())) + defer cancel() + + // convert to map + updatesMap := make(map[string]*persistence.SingleTaskQueueUserDataUpdate) + for _, update := range batch { + updatesMap[update.taskQueue] = &update.update + } + + // now apply the batch of updates + return e.taskManager.UpdateTaskQueueUserData(ctx, &persistence.UpdateTaskQueueUserDataRequest{ + NamespaceID: namespaceId.String(), + Updates: updatesMap, + }) } func (e *matchingEngineImpl) getHostInfo(partitionKey string) (string, error) { diff --git a/service/matching/matching_engine_test.go b/service/matching/matching_engine_test.go index 1eb726e2943..20e35937471 100644 --- a/service/matching/matching_engine_test.go +++ b/service/matching/matching_engine_test.go @@ -2409,8 +2409,11 @@ func (s *matchingEngineSuite) TestGetTaskQueueUserData_ReturnsData() { s.NoError(s.taskManager.UpdateTaskQueueUserData(context.Background(), &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: namespaceID.String(), - TaskQueue: tq, - UserData: userData, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + tq: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: userData, + }, + }, })) userData.Version++ @@ -2435,8 +2438,11 @@ func (s *matchingEngineSuite) 
TestGetTaskQueueUserData_ReturnsEmpty() { s.NoError(s.taskManager.UpdateTaskQueueUserData(context.Background(), &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: namespaceID.String(), - TaskQueue: tq, - UserData: userData, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + tq: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: userData, + }, + }, })) userData.Version++ @@ -2461,8 +2467,11 @@ func (s *matchingEngineSuite) TestGetTaskQueueUserData_LongPoll_Expires() { s.NoError(s.taskManager.UpdateTaskQueueUserData(context.Background(), &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: namespaceID.String(), - TaskQueue: tq, - UserData: userData, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + tq: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: userData, + }, + }, })) userData.Version++ @@ -2534,8 +2543,11 @@ func (s *matchingEngineSuite) TestGetTaskQueueUserData_LongPoll_WakesUp_From2to3 s.NoError(s.taskManager.UpdateTaskQueueUserData(context.Background(), &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: namespaceID.String(), - TaskQueue: tq, - UserData: userData, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + tq: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: userData, + }, + }, })) userData.Version++ @@ -2615,8 +2627,11 @@ func (s *matchingEngineSuite) TestUpdateUserData_FailsOnKnownVersionMismatch() { err := s.taskManager.UpdateTaskQueueUserData(context.Background(), &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: namespaceID.String(), - TaskQueue: tq, - UserData: userData, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + tq: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: userData, + }, + }, }) s.NoError(err) @@ -2711,10 +2726,13 @@ func (s *matchingEngineSuite) TestDemotedMatch() { err := s.taskManager.UpdateTaskQueueUserData(ctx, &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: namespaceId, - TaskQueue: tq, - UserData: &persistencespb.VersionedTaskQueueUserData{ - Data: userData, - Version: 34, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + tq: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: &persistencespb.VersionedTaskQueueUserData{ + Data: userData, + Version: 34, + }, + }, }, }) s.Assert().NoError(err) @@ -2761,10 +2779,13 @@ func (s *matchingEngineSuite) TestDemotedMatch() { err = s.taskManager.UpdateTaskQueueUserData(ctx, &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: namespaceId, - TaskQueue: tq, - UserData: &persistencespb.VersionedTaskQueueUserData{ - Data: userData, - Version: 34, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + tq: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: &persistencespb.VersionedTaskQueueUserData{ + Data: userData, + Version: 34, + }, + }, }, }) s.NoError(err) @@ -3907,16 +3928,18 @@ func (m *testTaskManager) GetTaskQueueUserData(_ context.Context, request *persi // UpdateTaskQueueUserData implements persistence.TaskManager func (m *testTaskManager) UpdateTaskQueueUserData(_ context.Context, request *persistence.UpdateTaskQueueUserDataRequest) error { - dbq, err := ParsePhysicalTaskQueueKey(request.TaskQueue, request.NamespaceID, enumspb.TASK_QUEUE_TYPE_WORKFLOW) - if err != nil { - return err + for tq, update := range request.Updates { + dbq, err := ParsePhysicalTaskQueueKey(tq, request.NamespaceID, enumspb.TASK_QUEUE_TYPE_WORKFLOW) + if err != nil { + return err + } + tlm := m.getQueueManager(dbq) + tlm.Lock() + newData := 
common.CloneProto(update.UserData) + newData.Version++ + tlm.userData = newData + tlm.Unlock() } - tlm := m.getQueueManager(dbq) - tlm.Lock() - defer tlm.Unlock() - newData := common.CloneProto(request.UserData) - newData.Version++ - tlm.userData = newData return nil } diff --git a/service/matching/user_data_manager_test.go b/service/matching/user_data_manager_test.go index 83e07bc4e43..d4bd4ba3ead 100644 --- a/service/matching/user_data_manager_test.go +++ b/service/matching/user_data_manager_test.go @@ -104,8 +104,11 @@ func TestUserData_LoadOnInit(t *testing.T) { require.NoError(t, m.store.UpdateTaskQueueUserData(context.Background(), &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: defaultNamespaceId, - TaskQueue: defaultRootTqID, - UserData: data1, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + defaultRootTqID: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: data1, + }, + }, })) data1.Version++ @@ -140,8 +143,11 @@ func TestUserData_LoadOnInit_Refresh(t *testing.T) { require.NoError(t, m.store.UpdateTaskQueueUserData(context.Background(), &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: defaultNamespaceId, - TaskQueue: defaultRootTqID, - UserData: data1, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + defaultRootTqID: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: data1, + }, + }, })) data1.Version++ @@ -167,8 +173,11 @@ func TestUserData_LoadOnInit_Refresh(t *testing.T) { require.NoError(t, m.store.UpdateTaskQueueUserData(context.Background(), &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: defaultNamespaceId, - TaskQueue: defaultRootTqID, - UserData: data2, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + defaultRootTqID: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: data2, + }, + }, })) data2.Version++ @@ -204,8 +213,11 @@ func TestUserData_LoadOnInit_Refresh_Backwards(t *testing.T) { require.NoError(t, m.store.UpdateTaskQueueUserData(context.Background(), &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: defaultNamespaceId, - TaskQueue: defaultRootTqID, - UserData: data5, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + defaultRootTqID: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: data5, + }, + }, })) data5.Version++ @@ -231,8 +243,11 @@ func TestUserData_LoadOnInit_Refresh_Backwards(t *testing.T) { require.NoError(t, m.store.UpdateTaskQueueUserData(context.Background(), &persistence.UpdateTaskQueueUserDataRequest{ NamespaceID: defaultNamespaceId, - TaskQueue: defaultRootTqID, - UserData: data4, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + defaultRootTqID: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: data4, + }, + }, })) data4.Version++ @@ -806,11 +821,14 @@ func TestUserData_Propagation(t *testing.T) { opts.matchingClientMock.EXPECT().UpdateTaskQueueUserData(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, req *matchingservice.UpdateTaskQueueUserDataRequest, opts ...grpc.CallOption) (*matchingservice.UpdateTaskQueueUserDataResponse, error) { err := tm.UpdateTaskQueueUserData(ctx, &persistence.UpdateTaskQueueUserDataRequest{ - NamespaceID: req.NamespaceId, - TaskQueue: req.TaskQueue, - UserData: req.UserData, - BuildIdsAdded: req.BuildIdsAdded, - BuildIdsRemoved: req.BuildIdsRemoved, + NamespaceID: req.NamespaceId, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + req.TaskQueue: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: req.UserData, + 
BuildIdsAdded: req.BuildIdsAdded, + BuildIdsRemoved: req.BuildIdsRemoved, + }, + }, }) return &matchingservice.UpdateTaskQueueUserDataResponse{}, err }, @@ -897,11 +915,14 @@ func TestUserData_CheckPropagation(t *testing.T) { opts.matchingClientMock.EXPECT().UpdateTaskQueueUserData(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, req *matchingservice.UpdateTaskQueueUserDataRequest, opts ...grpc.CallOption) (*matchingservice.UpdateTaskQueueUserDataResponse, error) { err := tm.UpdateTaskQueueUserData(ctx, &persistence.UpdateTaskQueueUserDataRequest{ - NamespaceID: req.NamespaceId, - TaskQueue: req.TaskQueue, - UserData: req.UserData, - BuildIdsAdded: req.BuildIdsAdded, - BuildIdsRemoved: req.BuildIdsRemoved, + NamespaceID: req.NamespaceId, + Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{ + req.TaskQueue: &persistence.SingleTaskQueueUserDataUpdate{ + UserData: req.UserData, + BuildIdsAdded: req.BuildIdsAdded, + BuildIdsRemoved: req.BuildIdsRemoved, + }, + }, }) return &matchingservice.UpdateTaskQueueUserDataResponse{}, err }, diff --git a/tests/versioning_3_test.go b/tests/versioning_3_test.go index 473152212d9..67bbc9e13a3 100644 --- a/tests/versioning_3_test.go +++ b/tests/versioning_3_test.go @@ -878,8 +878,8 @@ func respondWftWithActivities( // TODO (shahab): tests with forced task forward take multiple seconds. Need to know why? ScheduleToCloseTimeout: durationpb.New(10 * time.Second), ScheduleToStartTimeout: durationpb.New(10 * time.Second), - StartToCloseTimeout: durationpb.New(1 * time.Second), - HeartbeatTimeout: durationpb.New(1 * time.Second), + StartToCloseTimeout: durationpb.New(3 * time.Second), + HeartbeatTimeout: durationpb.New(3 * time.Second), RequestEagerExecution: false, }, },
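
A minimal sketch of the new EventTimeSource helpers added above: Sleep blocks the calling goroutine on a fake timer, and AdvanceNext jumps the fake clock to the earliest pending deadline and fires it. The package name, durations, and real-time polling loop here are illustrative assumptions, not part of the change.

package example

import (
	"time"

	"go.temporal.io/server/common/clock"
)

// advanceFakeClock blocks one goroutine on the fake clock, then releases it by
// advancing straight to its timer's deadline.
func advanceFakeClock() {
	ts := clock.NewEventTimeSource()

	done := make(chan struct{})
	go func() {
		ts.Sleep(250 * time.Millisecond) // registers a fake timer and blocks until it fires
		close(done)
	}()

	// Wait in real time until the goroutine has registered its timer, then fire it.
	for ts.NumTimers() == 0 {
		time.Sleep(time.Millisecond)
	}
	ts.AdvanceNext()
	<-done
}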
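
UpdateTaskQueueUserDataRequest now carries a map of per-task-queue updates keyed by task queue name, each with an optional Conflicting flag that is set when that particular update is identified as the version conflict. A hedged sketch of a caller using the new shape; the helper name, queue names, and data arguments are placeholders:

package example

import (
	"context"

	persistencespb "go.temporal.io/server/api/persistence/v1"
	"go.temporal.io/server/common/persistence"
)

// updateTwoQueues writes user data for two task queues in one request and
// reports whether the second update was the one that hit a version conflict.
func updateTwoQueues(
	ctx context.Context,
	tm persistence.TaskManager,
	namespaceID string,
	d1, d2 *persistencespb.VersionedTaskQueueUserData,
) (secondConflicted bool, err error) {
	var conflicting2 bool
	err = tm.UpdateTaskQueueUserData(ctx, &persistence.UpdateTaskQueueUserDataRequest{
		NamespaceID: namespaceID,
		Updates: map[string]*persistence.SingleTaskQueueUserDataUpdate{
			"tq1": {UserData: d1},
			"tq2": {UserData: d2, Conflicting: &conflicting2},
		},
	})
	// Conflicting is only meaningful when the overall error is a conflict error;
	// a conflicting update should not be retried with the same version.
	return conflicting2 && persistence.IsConflictErr(err), err
}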
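
stream_batcher.Batcher coalesces concurrent Add calls into batches and hands every caller in a batch the same result from the processing function. A small usage sketch against the real clock; the summing callback is illustrative, and the option values echo the defaults matching uses for user data updates:

package example

import (
	"context"
	"sync"
	"time"

	"go.temporal.io/server/common/clock"
	"go.temporal.io/server/common/stream_batcher"
)

// sumConcurrently submits values from several goroutines; each Add blocks until
// its batch is processed and returns the total computed for that batch.
func sumConcurrently(values []int) {
	process := func(items []int) (total int) {
		for _, i := range items {
			total += i
		}
		return
	}
	b := stream_batcher.NewBatcher(process, stream_batcher.BatcherOptions{
		MaxItems: 100,
		MinDelay: 100 * time.Millisecond,
		MaxDelay: 500 * time.Millisecond,
		IdleTime: time.Minute, // the internal goroutine exits after a minute with no Adds
	}, clock.NewRealTimeSource())

	var wg sync.WaitGroup
	for _, v := range values {
		wg.Add(1)
		go func(v int) {
			defer wg.Done()
			total, err := b.Add(context.Background(), v)
			_ = total // every caller in the same batch receives the same total
			_ = err   // non-nil only if the context was canceled or timed out
		}(v)
	}
	wg.Wait()
}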