Skip to content

Commit

Permalink
feat(client-support/netcord): Implement support for sharded clients
Browse files Browse the repository at this point in the history
  • Loading branch information
angelobreuer committed Feb 28, 2024
1 parent 9e5a41f commit c0a4c24
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 84 deletions.
104 changes: 21 additions & 83 deletions src/Lavalink4NET.NetCord/DiscordClientWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,123 +2,61 @@

using System;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using global::NetCord.Gateway;
using Lavalink4NET.Clients;
using Lavalink4NET.Clients.Events;
using Lavalink4NET.Events;

public sealed class DiscordClientWrapper : IDiscordClientWrapper, IDisposable
public sealed class DiscordClientWrapper : IDiscordClientWrapper
{
private readonly GatewayClient _client;
private readonly IDiscordClientWrapper _client;

public DiscordClientWrapper(GatewayClient client)
{
ArgumentNullException.ThrowIfNull(client);

_client = client;

_client.VoiceStateUpdate += HandleVoiceStateUpdateAsync;
_client.VoiceServerUpdate += HandleVoiceServerUpdateAsync;
_client = new SocketDiscordClientWrapper(client);
}

public event AsyncEventHandler<VoiceServerUpdatedEventArgs>? VoiceServerUpdated;

public event AsyncEventHandler<VoiceStateUpdatedEventArgs>? VoiceStateUpdated;

public void Dispose()
public DiscordClientWrapper(ShardedGatewayClient client)
{
_client.VoiceStateUpdate -= HandleVoiceStateUpdateAsync;
_client.VoiceServerUpdate -= HandleVoiceServerUpdateAsync;
ArgumentNullException.ThrowIfNull(client);

_client = new ShardedDiscordClientWrapper(client);
}

public ValueTask<ImmutableArray<ulong>> GetChannelUsersAsync(ulong guildId, ulong voiceChannelId, bool includeBots = false, CancellationToken cancellationToken = default)
public event AsyncEventHandler<VoiceServerUpdatedEventArgs>? VoiceServerUpdated
{
cancellationToken.ThrowIfCancellationRequested();

if (!_client.Cache.Guilds.TryGetValue(guildId, out var guild))
{
return new ValueTask<ImmutableArray<ulong>>([]);
}

var voiceStates = guild.VoiceStates
.Where(x => x.Value.ChannelId == voiceChannelId)
.Where(x => x.Value.UserId != _client.Id);

if (!includeBots)
{
voiceStates = voiceStates.Where(x => x.Value.User is not { IsBot: true, });
}

var userIds = voiceStates.Select(x => x.Value.UserId).ToImmutableArray();
return new ValueTask<ImmutableArray<ulong>>(userIds);
add => _client.VoiceServerUpdated += value;
remove => _client.VoiceServerUpdated -= value;
}

public async ValueTask SendVoiceUpdateAsync(ulong guildId, ulong? voiceChannelId, bool selfDeaf = false, bool selfMute = false, CancellationToken cancellationToken = default)
public event AsyncEventHandler<VoiceStateUpdatedEventArgs>? VoiceStateUpdated
{
cancellationToken.ThrowIfCancellationRequested();

var voiceStateProperties = new VoiceStateProperties(guildId, voiceChannelId) { SelfDeaf = selfDeaf, SelfMute = selfMute, };

await _client
.UpdateVoiceStateAsync(voiceStateProperties)
.ConfigureAwait(false);
add => _client.VoiceStateUpdated += value;
remove => _client.VoiceStateUpdated -= value;
}

public async ValueTask<ClientInformation> WaitForReadyAsync(CancellationToken cancellationToken = default)
public ValueTask<ImmutableArray<ulong>> GetChannelUsersAsync(ulong guildId, ulong voiceChannelId, bool includeBots = false, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();

await _client.ReadyAsync
.WaitAsync(cancellationToken)
.ConfigureAwait(false);

var shardCount = _client.Shard?.Count ?? 1;

return new ClientInformation("NetCord", _client.Id, shardCount);
return _client.GetChannelUsersAsync(guildId, voiceChannelId, includeBots, cancellationToken);
}

private ValueTask HandleVoiceServerUpdateAsync(VoiceServerUpdateEventArgs eventArgs)
public ValueTask SendVoiceUpdateAsync(ulong guildId, ulong? voiceChannelId, bool selfDeaf = false, bool selfMute = false, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(eventArgs);

if (eventArgs.Endpoint is null)
{
return default;
}

var voiceServerUpdatedEventArgs = new VoiceServerUpdatedEventArgs(
guildId: eventArgs.GuildId,
voiceServer: new VoiceServer(eventArgs.Token, eventArgs.Endpoint));
cancellationToken.ThrowIfCancellationRequested();

return VoiceServerUpdated.InvokeAsync(this, voiceServerUpdatedEventArgs);
return _client.SendVoiceUpdateAsync(guildId, voiceChannelId, selfDeaf, selfMute, cancellationToken);
}

private async ValueTask HandleVoiceStateUpdateAsync(global::NetCord.Gateway.VoiceState eventArgs)
public ValueTask<ClientInformation> WaitForReadyAsync(CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(eventArgs);

// Retrieve previous voice state from cache
var previousVoiceState = _client.Cache.Guilds.TryGetValue(eventArgs.GuildId, out var guild)
&& guild.VoiceStates.TryGetValue(eventArgs.UserId, out var previousVoiceStateData)
? new Clients.VoiceState(VoiceChannelId: previousVoiceStateData.ChannelId, SessionId: previousVoiceStateData.SessionId)
: default;

var updatedVoiceState = new Clients.VoiceState(
VoiceChannelId: eventArgs.ChannelId,
SessionId: eventArgs.SessionId);

var voiceStateUpdatedEventArgs = new VoiceStateUpdatedEventArgs(
eventArgs.GuildId,
eventArgs.UserId,
eventArgs.UserId == _client.Id,
updatedVoiceState,
previousVoiceState);
cancellationToken.ThrowIfCancellationRequested();

await VoiceStateUpdated
.InvokeAsync(this, voiceStateUpdatedEventArgs)
.ConfigureAwait(false);
return _client.WaitForReadyAsync(cancellationToken);
}
}
104 changes: 104 additions & 0 deletions src/Lavalink4NET.NetCord/DiscordClientWrapperBase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
namespace Lavalink4NET.NetCord;

using System;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using global::NetCord.Gateway;
using Lavalink4NET.Clients;
using Lavalink4NET.Clients.Events;
using Lavalink4NET.Events;

internal abstract class DiscordClientWrapperBase : IDiscordClientWrapper
{
public event AsyncEventHandler<VoiceServerUpdatedEventArgs>? VoiceServerUpdated;

public event AsyncEventHandler<VoiceStateUpdatedEventArgs>? VoiceStateUpdated;

public ValueTask<ImmutableArray<ulong>> GetChannelUsersAsync(ulong guildId, ulong voiceChannelId, bool includeBots = false, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();

if (!TryGetGuild(guildId, out var guild))
{
return new ValueTask<ImmutableArray<ulong>>([]);
}

var currentUserId = GetClient(guildId).Id;

var voiceStates = guild.VoiceStates
.Where(x => x.Value.ChannelId == voiceChannelId)
.Where(x => x.Value.UserId != currentUserId);

if (!includeBots)
{
voiceStates = voiceStates.Where(x => x.Value.User is not { IsBot: true, });
}

var userIds = voiceStates.Select(x => x.Value.UserId).ToImmutableArray();
return new ValueTask<ImmutableArray<ulong>>(userIds);
}

public async ValueTask SendVoiceUpdateAsync(ulong guildId, ulong? voiceChannelId, bool selfDeaf = false, bool selfMute = false, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();

var voiceStateProperties = new VoiceStateProperties(guildId, voiceChannelId) { SelfDeaf = selfDeaf, SelfMute = selfMute, };

await GetClient(guildId)
.UpdateVoiceStateAsync(voiceStateProperties)
.ConfigureAwait(false);
}

protected abstract bool TryGetGuild(ulong guildId, [MaybeNullWhen(false)] out Guild guild);

protected abstract GatewayClient GetClient(ulong guildId);

protected ValueTask HandleVoiceServerUpdateAsync(VoiceServerUpdateEventArgs eventArgs)
{
ArgumentNullException.ThrowIfNull(eventArgs);

if (eventArgs.Endpoint is null)
{
return default;
}

var voiceServerUpdatedEventArgs = new VoiceServerUpdatedEventArgs(
guildId: eventArgs.GuildId,
voiceServer: new VoiceServer(eventArgs.Token, eventArgs.Endpoint));

return VoiceServerUpdated.InvokeAsync(this, voiceServerUpdatedEventArgs);
}

protected async ValueTask HandleVoiceStateUpdateAsync(global::NetCord.Gateway.VoiceState eventArgs)
{
ArgumentNullException.ThrowIfNull(eventArgs);

// Retrieve previous voice state from cache
var previousVoiceState = TryGetGuild(eventArgs.GuildId, out var guild)
&& guild.VoiceStates.TryGetValue(eventArgs.UserId, out var previousVoiceStateData)
? new Clients.VoiceState(VoiceChannelId: previousVoiceStateData.ChannelId, SessionId: previousVoiceStateData.SessionId)
: default;

var currentUserId = GetClient(eventArgs.GuildId).Id;

var updatedVoiceState = new Clients.VoiceState(
VoiceChannelId: eventArgs.ChannelId,
SessionId: eventArgs.SessionId);

var voiceStateUpdatedEventArgs = new VoiceStateUpdatedEventArgs(
eventArgs.GuildId,
eventArgs.UserId,
eventArgs.UserId == currentUserId,
updatedVoiceState,
previousVoiceState);

await VoiceStateUpdated
.InvokeAsync(this, voiceStateUpdatedEventArgs)
.ConfigureAwait(false);
}

public abstract ValueTask<ClientInformation> WaitForReadyAsync(CancellationToken cancellationToken = default);
}
2 changes: 1 addition & 1 deletion src/Lavalink4NET.NetCord/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public static IServiceCollection AddLavalink(this IServiceCollection services)
{
ArgumentNullException.ThrowIfNull(services);

services.AddLavalink<DiscordClientWrapper>();
services.AddLavalink<SocketDiscordClientWrapper>();

return services;
}
Expand Down
77 changes: 77 additions & 0 deletions src/Lavalink4NET.NetCord/ShardedDiscordClientWrapper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
namespace Lavalink4NET.NetCord;

using System;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using global::NetCord.Gateway;
using Lavalink4NET.Clients;

internal sealed class ShardedDiscordClientWrapper : DiscordClientWrapperBase, IDiscordClientWrapper, IDisposable
{
private readonly ShardedGatewayClient _client;
private readonly TaskCompletionSource _readyTaskCompletionSource;

public ShardedDiscordClientWrapper(ShardedGatewayClient client)
{
ArgumentNullException.ThrowIfNull(client);

_client = client;

_readyTaskCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

_client.VoiceStateUpdate += HandleVoiceStateUpdateAsync;
_client.VoiceServerUpdate += HandleVoiceServerUpdateAsync;
_client.Ready += HandleShardReadyAsync;
}

private ValueTask HandleShardReadyAsync(GatewayClient client, ReadyEventArgs eventArgs)
{
ArgumentNullException.ThrowIfNull(client);
ArgumentNullException.ThrowIfNull(eventArgs);

_readyTaskCompletionSource.TrySetResult();

return default;
}

private ValueTask HandleVoiceServerUpdateAsync(GatewayClient client, VoiceServerUpdateEventArgs args)
{
ArgumentNullException.ThrowIfNull(client);
ArgumentNullException.ThrowIfNull(args);

return HandleVoiceServerUpdateAsync(args);
}

private ValueTask HandleVoiceStateUpdateAsync(GatewayClient client, global::NetCord.Gateway.VoiceState state)
{
ArgumentNullException.ThrowIfNull(client);
ArgumentNullException.ThrowIfNull(state);

return HandleVoiceStateUpdateAsync(state);
}

public void Dispose()
{
_client.VoiceStateUpdate -= HandleVoiceStateUpdateAsync;
_client.VoiceServerUpdate -= HandleVoiceServerUpdateAsync;
}

public override async ValueTask<ClientInformation> WaitForReadyAsync(CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();

await _readyTaskCompletionSource.Task
.WaitAsync(cancellationToken)
.ConfigureAwait(false);

return new ClientInformation("NetCord", _client.Id, _client.Count);
}

protected override GatewayClient GetClient(ulong guildId) => _client[guildId: guildId];

protected override bool TryGetGuild(ulong guildId, [MaybeNullWhen(false)] out Guild guild)
{
return GetClient(guildId).Cache.Guilds.TryGetValue(guildId, out guild);
}
}
49 changes: 49 additions & 0 deletions src/Lavalink4NET.NetCord/SocketDiscordClientWrapper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
namespace Lavalink4NET.NetCord;

using System;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using global::NetCord.Gateway;
using Lavalink4NET.Clients;

internal sealed class SocketDiscordClientWrapper : DiscordClientWrapperBase, IDiscordClientWrapper, IDisposable
{
private readonly GatewayClient _client;

public SocketDiscordClientWrapper(GatewayClient client)
{
ArgumentNullException.ThrowIfNull(client);

_client = client;

_client.VoiceStateUpdate += HandleVoiceStateUpdateAsync;
_client.VoiceServerUpdate += HandleVoiceServerUpdateAsync;
}

public void Dispose()
{
_client.VoiceStateUpdate -= HandleVoiceStateUpdateAsync;
_client.VoiceServerUpdate -= HandleVoiceServerUpdateAsync;
}

public override async ValueTask<ClientInformation> WaitForReadyAsync(CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();

await _client.ReadyAsync
.WaitAsync(cancellationToken)
.ConfigureAwait(false);

var shardCount = _client.Shard?.Count ?? 1;

return new ClientInformation("NetCord", _client.Id, shardCount);
}

protected override GatewayClient GetClient(ulong guildId) => _client;

protected override bool TryGetGuild(ulong guildId, [MaybeNullWhen(false)] out Guild guild)
{
return _client.Cache.Guilds.TryGetValue(guildId, out guild);
}
}

0 comments on commit c0a4c24

Please sign in to comment.