Skip to content

Commit

Permalink
[MWB] - pub-sub spike - added a basic pub / sub worker implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeclayton committed Jan 24, 2025
1 parent 458e5c5 commit 8273130
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 0 deletions.
92 changes: 92 additions & 0 deletions src/modules/MouseWithoutBorders/App/Messaging/PacketConsumer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

namespace MouseWithoutBorders.Messaging;

internal sealed class PacketConsumer
{
public PacketConsumer(Func<DATA, CancellationToken, Task> callback)
{
this.Channel = System.Threading.Channels.Channel.CreateBounded<DATA>(
new BoundedChannelOptions(100)
{
SingleWriter = true,
SingleReader = true,
AllowSynchronousContinuations = true,
FullMode = BoundedChannelFullMode.Wait,
});
this.Callback = callback ?? throw new ArgumentNullException(nameof(callback));
}

/// <remarks>
/// Each PacketConsumer has a private channel to store its own copy of messages.
/// When a message is posted to a PacketQueue it gets multiplexed to all the subscribing
/// PacketConsumers.
/// </remarks>
private Channel<DATA> Channel
{
get;
}

private Func<DATA, CancellationToken, Task> Callback
{
get;
}

public int Count
=> this.Channel.Reader.Count;

public async ValueTask WriteAsync(DATA packet, CancellationToken cancellationToken = default)
{
await this.Channel.Writer.WriteAsync(packet, cancellationToken);
}

public bool TryWrite(DATA packet)
{
return this.Channel.Writer.TryWrite(packet);
}

/// <remarks>
/// See https://devblogs.microsoft.com/dotnet/an-introduction-to-system-threading-channels/
/// </remarks>
public async ValueTask StartAsync(CancellationToken cancellationToken = default)
{
var reader = this.Channel.Reader;
while (true)
{
if (!await reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
{
throw new ChannelClosedException();
}

if (reader.TryRead(out var packet))
{
await this.Callback(packet, cancellationToken);
}
}
}

/// <summary>
/// Reads and processes all messages currently on the queue until it is empty.
/// Any messages that arrive while draining will be read and processed as well.
/// Does *not* "Complete" the queue, just leaves it empty.
/// </summary>
public async Task DrainAsync(CancellationToken cancellationToken = default)
{
while (this.Channel.Reader.Count > 0)
{
await Task.Delay(250, cancellationToken);
}
}

public void Stop()
{
this.Channel.Writer.Complete();
}
}
31 changes: 31 additions & 0 deletions src/modules/MouseWithoutBorders/App/Messaging/PacketProducer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Threading;
using System.Threading.Tasks;

namespace MouseWithoutBorders.Messaging;

internal sealed class PacketProducer
{
public PacketProducer()
{
this.Queue = new();
}

public PacketQueue Queue
{
get;
}

public async ValueTask WriteAsync(DATA packet, CancellationToken cancellationToken = default)
{
await this.Queue.WriteAsync(packet, cancellationToken);
}

public bool TryWrite(DATA item)
{
return this.Queue.TryWrite(item);
}
}
80 changes: 80 additions & 0 deletions src/modules/MouseWithoutBorders/App/Messaging/PacketQueue.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace MouseWithoutBorders.Messaging;

internal sealed class PacketQueue
{
public PacketQueue()
{
this.Consumers = [];
}

private Lock _consumerLock = new();

private ConcurrentBag<PacketConsumer> Consumers
{
get;
set;
}

public void Subscribe(PacketConsumer consumer)
{
ArgumentNullException.ThrowIfNull(consumer);

// we still need to lock because Unsubscribe replaces the instance
// so there's a race condition where Subscribe could be called to
// add a new consumer half way though Unsubscribe already running
// and the new consumer getting lost from the new value.
lock (this._consumerLock)
{
this.Consumers.Add(consumer);
}
}

public void Unsubscribe(PacketConsumer consumer)
{
ArgumentNullException.ThrowIfNull(consumer);

// we still need to lock because Unsubscribe replaces the instance
// so there's a race condition where Subscribe could be called to
// add a new consumer half way though Unsubscribe already running
// and the new consumer getting lost from the new value.
lock (this._consumerLock)
{
this.Consumers = new(
this.Consumers.Where(
entry => !object.ReferenceEquals(entry, consumer)));
}
}

public async ValueTask WriteAsync(DATA packet, CancellationToken cancellationToken = default)
{
// we don't need to lock while enumerating because we don't care too much
// if a single message gets lost while a new consumer is being added
foreach (var consumer in this.Consumers)
{
await consumer.WriteAsync(packet, cancellationToken);
}
}

public bool TryWrite(DATA packet)
{
// we don't need to lock while enumerating because we don't care too much
// if a single message gets lost while a new consumer is being added
var result = true;
foreach (var consumer in this.Consumers)
{
result &= consumer.TryWrite(packet);
}

return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Concurrent;
using System.Diagnostics;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MouseWithoutBorders.Messaging;

namespace MouseWithoutBorders.UnitTests.Messaging;

public static class PacketProducerTests
{
[TestClass]
public sealed class GeneralTests
{
/// <summary>
/// Performs a basic smoke and performance test, and ensures that the same number of messages
/// posted to a PacketQueue by a PacketProducer are received and processed by multiple PacketConsumers.
/// </summary>
[TestMethod]
public async Task BasicSmokeAndPerformanceTest()
{
// some bookkeeping for the test itself
var messageCount = 1_000_000;
var triggers = new ConcurrentDictionary<string, int>();

// make a producer that we'll use to push messages onto a queue
var producer = new PacketProducer();

// subscribe a first consumer to the producer's queue and start it.
// when invoked, it just updates how many times it's been called so we can make sure no messages get missed
PacketConsumer consumer1 = new(
(DATA packet, CancellationToken cancellationToken) =>
{
triggers.AddOrUpdate(nameof(consumer1), 1, (key, oldValue) => oldValue + 1);
return Task.CompletedTask;
});
producer.Queue.Subscribe(consumer1);
var task1 = Task.Run(() => consumer1.StartAsync());

// subscribe a second consumer to the producer's queue and start it.
// when invoked, it just updates how many times it's been called so we can make sure no messages get missed
PacketConsumer consumer2 = new(
(DATA packet, CancellationToken cancellationToken) =>
{
triggers.AddOrUpdate(nameof(consumer2), 1, (key, oldValue) => oldValue + 1);
return Task.CompletedTask;
});
producer.Queue.Subscribe(consumer2);
var task2 = Task.Run(() => consumer2.StartAsync());

// post a bunch of messages onto the queue
var stopwatch = Stopwatch.StartNew();
for (var i = 0; i < messageCount; i++)
{
await producer.WriteAsync(new());
}

// wait for all the messages to be processed by both consumers
await Task.WhenAll(
consumer1.DrainAsync(),
consumer2.DrainAsync());

// check how long it took to process the messages
// this should typically only be a few thousand milliseconds for about 1,000,000 messages
stopwatch.Stop();
Console.WriteLine($"{messageCount:N0} messages processed in {stopwatch.ElapsedMilliseconds}ms");

// did we miss any messages?
Assert.IsTrue(triggers.ContainsKey(nameof(consumer1)));
Assert.AreEqual(messageCount, triggers[nameof(consumer1)]);
Assert.IsTrue(triggers.ContainsKey(nameof(consumer2)));
Assert.AreEqual(messageCount, triggers[nameof(consumer2)]);

// the test will ideally a *little* bit quicker than this, but we'll set it as
// an upper limit so we don't get lots of false negatives.
var performanceGoal = 4000; // milliseconds
Assert.IsTrue(
stopwatch.ElapsedMilliseconds <= performanceGoal,
$"Time taken was expected to be {performanceGoal}ms or less, but was {stopwatch.ElapsedMilliseconds}ms.");
}
}
}

0 comments on commit 8273130

Please sign in to comment.