Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thread-safe message scheduling and related tests #1638

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions src/DotNetCore.CAP/Internal/ScheduledMediumMessageQueue.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using DotNetCore.CAP.Persistence;

namespace DotNetCore.CAP.Internal;

public class ScheduledMediumMessageQueue
{
private readonly SortedSet<(long, MediumMessage)> _queue = new(Comparer<(long, MediumMessage)>.Create((a, b) =>
{
int result = a.Item1.CompareTo(b.Item1);
return result == 0 ? String.Compare(a.Item2.DbId, b.Item2.DbId, StringComparison.Ordinal) : result;
}));

private readonly SemaphoreSlim _semaphore = new(0);
private readonly object _lock = new();

public void Enqueue(MediumMessage message, long sendTime)
{
lock (_lock)
{
_queue.Add((sendTime, message));
}

_semaphore.Release();
}

public int Count
{
get
{
lock (_lock)
{
return _queue.Count;
}
}
}

public IEnumerable<MediumMessage> UnorderedItems
{
get
{
lock (_lock)
{
return _queue.Select(x => x.Item2).ToList();
}
}
}

public async IAsyncEnumerable<MediumMessage> GetConsumingEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
while (!cancellationToken.IsCancellationRequested)
{
await _semaphore.WaitAsync(cancellationToken);

(long, MediumMessage)? nextItem = null;

lock (_lock)
{
if (_queue.Count > 0)
{
var topMessage = _queue.First();
var timeLeft = topMessage.Item1 - DateTime.Now.Ticks;
if (timeLeft < 500000) // 50ms
{
nextItem = topMessage;
_queue.Remove(topMessage);
}
}
}

if (nextItem is not null)
{
yield return nextItem.Value.Item2;
}
else
{
// Re-release the semaphore if no item is ready yet
_semaphore.Release();
await Task.Delay(50, cancellationToken);
}
}
}
}
36 changes: 15 additions & 21 deletions src/DotNetCore.CAP/Processor/IDispatcher.Default.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,19 @@ namespace DotNetCore.CAP.Processor;

public class Dispatcher : IDispatcher
{
private readonly CancellationTokenSource _delayCts = new();
private readonly ISubscribeExecutor _executor;
private readonly ILogger<Dispatcher> _logger;
private readonly CapOptions _options;
private readonly IMessageSender _sender;
private readonly IDataStorage _storage;
private readonly PriorityQueue<MediumMessage, long> _schedulerQueue;
private readonly ScheduledMediumMessageQueue _schedulerQueue = new();
private readonly bool _enableParallelExecute;
private readonly bool _enableParallelSend;
private readonly int _pChannelSize;

private CancellationTokenSource? _tasksCts;
private Channel<MediumMessage> _publishedChannel = default!;
private Channel<(MediumMessage, ConsumerExecutorDescriptor?)> _receivedChannel = default!;
private long _nextSendTime = DateTime.MaxValue.Ticks;

public Dispatcher(ILogger<Dispatcher> logger, IMessageSender sender, IOptions<CapOptions> options,
ISubscribeExecutor executor, IDataStorage storage)
Expand All @@ -41,7 +39,6 @@ public Dispatcher(ILogger<Dispatcher> logger, IMessageSender sender, IOptions<Ca
_sender = sender;
_options = options.Value;
_executor = executor;
_schedulerQueue = new PriorityQueue<MediumMessage, long>();
_storage = storage;
_enableParallelExecute = options.Value.EnableSubscriberParallelExecute;
_enableParallelSend = options.Value.EnablePublishParallelSend;
Expand All @@ -52,7 +49,6 @@ public async Task Start(CancellationToken stoppingToken)
{
stoppingToken.ThrowIfCancellationRequested();
_tasksCts = CancellationTokenSource.CreateLinkedTokenSource(stoppingToken, CancellationToken.None);
_tasksCts.Token.Register(() => _delayCts.Cancel());

_publishedChannel = Channel.CreateBounded<MediumMessage>(new BoundedChannelOptions(_pChannelSize)
{
Expand Down Expand Up @@ -88,7 +84,7 @@ await Task.WhenAll(Enumerable.Range(0, _options.SubscriberParallelExecuteThreadC
{
if (_schedulerQueue.Count == 0) return;

var messageIds = _schedulerQueue.UnorderedItems.Select(x => x.Element.DbId).ToArray();
var messageIds = _schedulerQueue.UnorderedItems.Select(x => x.DbId).ToArray();
_storage.ChangePublishStateToDelayedAsync(messageIds).GetAwaiter().GetResult();
_logger.LogDebug("Update storage to delayed success of delayed message in memory queue!");
}
Expand All @@ -102,29 +98,32 @@ await Task.WhenAll(Enumerable.Range(0, _options.SubscriberParallelExecuteThreadC
{
try
{
while (_schedulerQueue.TryPeek(out _, out _nextSendTime))
await foreach (var nextMessage in _schedulerQueue.GetConsumingEnumerable(_tasksCts.Token))
{
var delayTime = _nextSendTime - DateTime.Now.Ticks;

if (delayTime > 500000) //50ms
{
await Task.Delay(new TimeSpan(delayTime), _delayCts.Token);
}
_tasksCts.Token.ThrowIfCancellationRequested();

await _sender.SendAsync(_schedulerQueue.Dequeue()).ConfigureAwait(false);
await _sender.SendAsync(nextMessage).ConfigureAwait(false);
}

_tasksCts.Token.WaitHandle.WaitOne(100);
}
catch (OperationCanceledException)
{
//Ignore
}
catch (Exception ex)
{
_logger.LogWarning(ex,
"Scheduled message publishing failed unexpectedly, which will stop future scheduled " +
"messages from publishing. See more details here: https://github.com/dotnetcore/CAP/issues/1637. " +
"Exception: {Message}",
ex.Message);
throw;
}
}
}, _tasksCts.Token).ConfigureAwait(false);
}

public async ValueTask EnqueueToScheduler(MediumMessage message, DateTime publishTime, object? transaction = null)
public async Task EnqueueToScheduler(MediumMessage message, DateTime publishTime, object? transaction = null)
{
message.ExpiresAt = publishTime;

Expand All @@ -135,11 +134,6 @@ public async ValueTask EnqueueToScheduler(MediumMessage message, DateTime publis
await _storage.ChangePublishStateAsync(message, StatusName.Queued, transaction);

_schedulerQueue.Enqueue(message, publishTime.Ticks);

if (publishTime.Ticks < _nextSendTime)
{
_delayCts.Cancel();
}
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion src/DotNetCore.CAP/Transport/IDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ public interface IDispatcher : IProcessingServer

ValueTask EnqueueToExecute(MediumMessage message, ConsumerExecutorDescriptor? descriptor = null);

ValueTask EnqueueToScheduler(MediumMessage message, DateTime publishTime, object? transaction = null);
Task EnqueueToScheduler(MediumMessage message, DateTime publishTime, object? transaction = null);
}
179 changes: 179 additions & 0 deletions test/DotNetCore.CAP.Test/DispatcherTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using DotNetCore.CAP.Internal;
using DotNetCore.CAP.Messages;
using DotNetCore.CAP.Persistence;
using DotNetCore.CAP.Processor;
using DotNetCore.CAP.Test.Helpers;
using FluentAssertions;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using NSubstitute;
using Xunit;

namespace DotNetCore.CAP.Test;

public class DispatcherTests
{
private readonly ILogger<Dispatcher> _logger;
private readonly ISubscribeExecutor _executor;
private readonly IDataStorage _storage;

public DispatcherTests()
{
_logger = Substitute.For<ILogger<Dispatcher>>();
_executor = Substitute.For<ISubscribeExecutor>();
_storage = Substitute.For<IDataStorage>();
}

[Fact]
public async Task EnqueueToPublish_ShouldInvokeSend_WhenParallelSendDisabled()
{
// Arrange
var sender = new TestThreadSafeMessageSender();
var options = Options.Create(new CapOptions
{
EnableSubscriberParallelExecute = true,
EnablePublishParallelSend = false,
SubscriberParallelExecuteThreadCount = 2,
SubscriberParallelExecuteBufferFactor = 2
});

var dispatcher = new Dispatcher(_logger, sender, options, _executor, _storage);

using var cts = new CancellationTokenSource();
var messageId = "testId";

// Act
await dispatcher.Start(cts.Token);
await dispatcher.EnqueueToPublish(CreateTestMessage(messageId));
await cts.CancelAsync();

// Assert
sender.Count.Should().Be(1);
sender.ReceivedMessages.First().DbId.Should().Be(messageId);
}

[Fact]
public async Task EnqueueToPublish_ShouldBeThreadSafe_WhenParallelSendDisabled()
{
// Arrange
var sender = new TestThreadSafeMessageSender();
var options = Options.Create(new CapOptions
{
EnableSubscriberParallelExecute = true,
EnablePublishParallelSend = false,
SubscriberParallelExecuteThreadCount = 2,
SubscriberParallelExecuteBufferFactor = 2
});
var dispatcher = new Dispatcher(_logger, sender, options, _executor, _storage);

using var cts = new CancellationTokenSource();
var messages = Enumerable.Range(1, 100)
.Select(i => CreateTestMessage(i.ToString()))
.ToArray();

// Act
await dispatcher.Start(cts.Token);

var tasks = messages
.Select(msg => Task.Run(() => dispatcher.EnqueueToPublish(msg), CancellationToken.None));
await Task.WhenAll(tasks);
await cts.CancelAsync();

// Assert
sender.Count.Should().Be(100);
var receivedMessages = sender.ReceivedMessages.Select(m => m.DbId).Order().ToList();
var expected = messages.Select(m => m.DbId).Order().ToList();
expected.Should().Equal(receivedMessages);
}

[Fact]
public async Task EnqueueToScheduler_ShouldBeThreadSafe_WhenDelayLessThenMinute()
{
// Arrange
var sender = new TestThreadSafeMessageSender();
var options = Options.Create(new CapOptions
{
EnableSubscriberParallelExecute = true,
EnablePublishParallelSend = false,
SubscriberParallelExecuteThreadCount = 2,
SubscriberParallelExecuteBufferFactor = 2
});
var dispatcher = new Dispatcher(_logger, sender, options, _executor, _storage);

using var cts = new CancellationTokenSource();
var messages = Enumerable.Range(1, 10000)
.Select(i => CreateTestMessage(i.ToString()))
.ToArray();

// Act
await dispatcher.Start(cts.Token);
var dateTime = DateTime.Now.AddSeconds(1);
await Parallel.ForEachAsync(messages, CancellationToken.None,
async (m, ct) => { await dispatcher.EnqueueToScheduler(m, dateTime); });

await Task.Delay(1500, CancellationToken.None);

await cts.CancelAsync();

// Assert
sender.Count.Should().Be(10000);

var receivedMessages = sender.ReceivedMessages.Select(m => m.DbId).Order().ToList();
var expected = messages.Select(m => m.DbId).Order().ToList();
expected.Should().Equal(receivedMessages);
}

[Fact]
public async Task EnqueueToScheduler_ShouldSendMessagesInCorrectOrder_WhenEarlierMessageIsSentLater()
{
// Arrange
var sender = new TestThreadSafeMessageSender();
var options = Options.Create(new CapOptions
{
EnableSubscriberParallelExecute = true,
EnablePublishParallelSend = false,
SubscriberParallelExecuteThreadCount = 2,
SubscriberParallelExecuteBufferFactor = 2
});
var dispatcher = new Dispatcher(_logger, sender, options, _executor, _storage);

using var cts = new CancellationTokenSource();
var messages = Enumerable.Range(1, 3)
.Select(i => CreateTestMessage(i.ToString()))
.ToArray();

// Act
await dispatcher.Start(cts.Token);
var dateTime = DateTime.Now;

await dispatcher.EnqueueToScheduler(messages[0], dateTime.AddSeconds(1));
await dispatcher.EnqueueToScheduler(messages[1], dateTime.AddMilliseconds(200));
await dispatcher.EnqueueToScheduler(messages[2], dateTime.AddMilliseconds(100));

await Task.Delay(1200, CancellationToken.None);
await cts.CancelAsync();

// Assert
sender.ReceivedMessages.Select(m => m.DbId).Should().Equal(["3", "2", "1"]);
}


private MediumMessage CreateTestMessage(string id = "1")
{
return new MediumMessage()
{
DbId = id,
Origin = new Message(
headers: new Dictionary<string, string>()
{
{ "cap-msg-id", id }
},
value: new MessageValue("[email protected]", "User"))
};
}
}
2 changes: 2 additions & 0 deletions test/DotNetCore.CAP.Test/DotNetCore.CAP.Test.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="FluentAssertions" Version="7.0.0" />
<PackageReference Include="NSubstitute" Version="5.3.0" />
</ItemGroup>

<ItemGroup>
Expand Down
Loading