Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup Lavalink4Net.DSharpPlus. #138

Merged
merged 7 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
namespace Lavalink4NET.DSharpPlus;

using System.Reflection;
using global::DSharpPlus;
using global::DSharpPlus.Entities;
using global::DSharpPlus.Net.WebSocket;

/// <summary>
/// An utility for getting internal / private fields from DSharpPlus WebSocket Gateway Payloads.
/// </summary>
public static class DSharpUtil
public static partial class DSharpPlusUtilities
{
/// <summary>
/// The internal "SessionId" property info in <see cref="DiscordVoiceState"/>.
Expand All @@ -19,27 +17,11 @@ public static class DSharpUtil
.GetProperty("SessionId", BindingFlags.NonPublic | BindingFlags.Instance)!;
#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields


/// <summary>
/// The internal "_webSocketClient" field info in <see cref="DiscordClient"/>.
/// </summary>
// https://github.com/DSharpPlus/DSharpPlus/blob/master/DSharpPlus/Clients/DiscordClient.WebSocket.cs#L54
private static readonly FieldInfo _webSocketClientField = typeof(DiscordClient)
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
.GetField("_webSocketClient", BindingFlags.NonPublic | BindingFlags.Instance)!;
#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields

/// <summary>
/// Gets the internal "SessionId" property value of the specified <paramref name="voiceState"/>.
/// </summary>
/// <param name="voiceState">the instance</param>
/// <returns>the "SessionId" value</returns>
public static string GetSessionId(this DiscordVoiceState voiceState)
=> (string)_sessionIdProperty.GetValue(voiceState)!;

/// <summary>
/// Gets the internal "_webSocketClient" field value of the specified <paramref name="client"/>.
/// </summary>
public static IWebSocketClient GetWebSocketClient(this DiscordClient client)
=> (IWebSocketClient)_webSocketClientField.GetValue(client)!;
}
178 changes: 100 additions & 78 deletions src/Lavalink4NET.DSharpPlus/DiscordClientWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,81 +2,74 @@ namespace Lavalink4NET.DSharpPlus;

using System;
using System.Collections.Immutable;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;
using global::DSharpPlus;
using global::DSharpPlus.Entities;
using global::DSharpPlus.EventArgs;
using global::DSharpPlus.Exceptions;
using global::DSharpPlus.Net.Abstractions;
using Lavalink4NET.Clients;
using Lavalink4NET.Clients.Events;
using Lavalink4NET.Events;
using Microsoft.Extensions.Logging;

/// <summary>
/// Wraps a <see cref="DiscordClient"/> or <see cref="DiscordShardedClient"/> instance.
/// </summary>
public sealed class DiscordClientWrapper : IDiscordClientWrapper, IDisposable
{
/// <inheritdoc/>
public event AsyncEventHandler<VoiceServerUpdatedEventArgs>? VoiceServerUpdated;

/// <inheritdoc/>
public event AsyncEventHandler<VoiceStateUpdatedEventArgs>? VoiceStateUpdated;

private readonly object _client; // either DiscordShardedClient or DiscordClient
private readonly ILogger<DiscordClientWrapper> _logger;
private readonly TaskCompletionSource<ClientInformation> _readyTaskCompletionSource;
private bool _disposed;

public DiscordClientWrapper(DiscordClient discordClient)
private DiscordClientWrapper(object discordClient, ILogger<DiscordClientWrapper> logger)
{
ArgumentNullException.ThrowIfNull(discordClient);
ArgumentNullException.ThrowIfNull(logger);

_client = discordClient;
_readyTaskCompletionSource = new TaskCompletionSource<ClientInformation>(TaskCreationOptions.RunContinuationsAsynchronously);
_logger = logger;

discordClient.VoiceStateUpdated += OnVoiceStateUpdated;
discordClient.VoiceServerUpdated += OnVoiceServerUpdated;
discordClient.Ready += OnClientReady;
_readyTaskCompletionSource = new TaskCompletionSource<ClientInformation>(TaskCreationOptions.RunContinuationsAsynchronously);
}

public DiscordClientWrapper(DiscordShardedClient discordClient)
/// <summary>
/// Creates a new instance of <see cref="DiscordClientWrapper"/>.
/// </summary>
/// <param name="discordClient">The Discord Client to wrap.</param>
/// <param name="logger">a logger associated with this wrapper.</param>
public DiscordClientWrapper(DiscordClient discordClient, ILogger<DiscordClientWrapper> logger)
: this((object)discordClient, logger)
{
ArgumentNullException.ThrowIfNull(discordClient);

_client = discordClient;
_readyTaskCompletionSource = new TaskCompletionSource<ClientInformation>(TaskCreationOptions.RunContinuationsAsynchronously);

discordClient.VoiceStateUpdated += OnVoiceStateUpdated;
discordClient.VoiceServerUpdated += OnVoiceServerUpdated;
discordClient.Ready += OnClientReady;
}

/// <inheritdoc/>
public event AsyncEventHandler<VoiceServerUpdatedEventArgs>? VoiceServerUpdated;

/// <inheritdoc/>
public event AsyncEventHandler<VoiceStateUpdatedEventArgs>? VoiceStateUpdated;

/// <inheritdoc/>
public void Dispose()
/// <summary>
/// Creates a new instance of <see cref="DiscordClientWrapper"/>.
/// </summary>
/// <param name="shardedDiscordClient">The Sharded Discord Client to wrap.</param>
/// <param name="logger">a logger associated with this wrapper.</param>
public DiscordClientWrapper(DiscordShardedClient shardedDiscordClient, ILogger<DiscordClientWrapper> logger)
: this((object)shardedDiscordClient, logger)
{
if (_disposed)
{
return;
}

_disposed = true;

if (_client is DiscordClient discordClient)
{
discordClient.VoiceStateUpdated -= OnVoiceStateUpdated;
discordClient.VoiceServerUpdated -= OnVoiceServerUpdated;
discordClient.Ready -= OnClientReady;
}
else
{
var shardedClient = Unsafe.As<object, DiscordShardedClient>(ref Unsafe.AsRef(_client));

shardedClient.VoiceStateUpdated -= OnVoiceStateUpdated;
shardedClient.VoiceServerUpdated -= OnVoiceServerUpdated;
shardedClient.Ready -= OnClientReady;
}
ArgumentNullException.ThrowIfNull(shardedDiscordClient);

shardedDiscordClient.VoiceStateUpdated += OnVoiceStateUpdated;
shardedDiscordClient.VoiceServerUpdated += OnVoiceServerUpdated;
shardedDiscordClient.Ready += OnClientReady;
}

/// <inheritdoc/>
Expand All @@ -95,35 +88,34 @@ public async ValueTask<ImmutableArray<ulong>> GetChannelUsersAsync(
channel = await GetClientForGuild(guildId)
.GetChannelAsync(voiceChannelId)
.ConfigureAwait(false);

if (channel is null)
{
return ImmutableArray<ulong>.Empty;
}
}
catch (UnauthorizedException)
catch (DiscordException exception)
{
// The channel was possibly deleted
return ImmutableArray<ulong>.Empty;
}
_logger.LogWarning(
exception, "An error occurred while retrieving the users for voice channel '{VoiceChannelId}' of the guild '{GuildId}'.",
voiceChannelId, guildId);

if (channel is null)
{
return ImmutableArray<ulong>.Empty;
}

var usersEnumerable = channel.Users.AsEnumerable();
var filteredUsers = ImmutableArray.CreateBuilder<ulong>(channel.Users.Count);

if (includeBots)
foreach (var member in channel.Users)
{
var currentUserId = _client is DiscordClient discordClient
? discordClient.CurrentUser.Id
: ((DiscordShardedClient)_client).CurrentUser.Id;

usersEnumerable = usersEnumerable.Where(x => x.Id != currentUserId);
}
else
{
usersEnumerable = usersEnumerable.Where(x => !x.IsBot);
// Always skip the current user.
// If we're not including bots and the member is a bot, skip them.
if (!member.IsCurrent || includeBots || !member.IsBot)
{
filteredUsers.Add(member.Id);
}
}


return usersEnumerable.Select(s => s.Id).ToImmutableArray();
return filteredUsers.ToImmutable();
}

/// <inheritdoc/>
Expand All @@ -137,13 +129,22 @@ public async ValueTask SendVoiceUpdateAsync(
{
cancellationToken.ThrowIfCancellationRequested();

var payload = new JsonObject();
var data = new VoiceStateUpdatePayload(guildId, voiceChannelId, selfMute, selfDeaf);

payload.Add("op", 4);
payload.Add("d", JsonSerializer.SerializeToNode(data));

await GetClientForGuild(guildId).GetWebSocketClient().SendMessageAsync(payload.ToString()).ConfigureAwait(false);
var client = GetClientForGuild(guildId);

var payload = new VoiceStateUpdatePayload(
guildId: guildId,
channelId: voiceChannelId,
isSelfMuted: selfMute,
isSelfDeafened: selfDeaf);

#pragma warning disable CS0618 // This method should not be used unless you know what you're doing. Instead, look towards the other explicitly implemented methods which come with client-side validation.
// Jan 23, 2024, OoLunar: We're telling Discord that we're joining a voice channel.
// At the time of writing, both DSharpPlus.VoiceNext and DSharpPlus.VoiceLink™
// use this method to send voice state updates.
await client
.SendPayloadAsync(GatewayOpCode.VoiceStateUpdate, JsonSerializer.Serialize(payload))
.ConfigureAwait(false);
#pragma warning restore CS0618 // This method should not be used unless you know what you're doing. Instead, look towards the other explicitly implemented methods which come with client-side validation.
}

/// <inheritdoc/>
Expand All @@ -153,16 +154,34 @@ public ValueTask<ClientInformation> WaitForReadyAsync(CancellationToken cancella
return new(_readyTaskCompletionSource.Task.WaitAsync(cancellationToken));
}

private DiscordClient GetClientForGuild(ulong guildId)
/// <inheritdoc/>
public void Dispose()
{
if (_client is DiscordClient discordClient)
if (_disposed)
{
return discordClient;
return;
}

return Unsafe.As<object, DiscordShardedClient>(ref Unsafe.AsRef(_client)).GetShard(guildId);
_disposed = true;

if (_client is DiscordClient discordClient)
{
discordClient.VoiceStateUpdated -= OnVoiceStateUpdated;
discordClient.VoiceServerUpdated -= OnVoiceServerUpdated;
discordClient.Ready -= OnClientReady;
}
else if (_client is DiscordShardedClient shardedClient)
{
shardedClient.VoiceStateUpdated -= OnVoiceStateUpdated;
shardedClient.VoiceServerUpdated -= OnVoiceServerUpdated;
shardedClient.Ready -= OnClientReady;
}
}

private DiscordClient GetClientForGuild(ulong guildId) => _client is DiscordClient discordClient
? discordClient
: ((DiscordShardedClient)_client).GetShard(guildId);

private Task OnClientReady(DiscordClient discordClient, ReadyEventArgs eventArgs)
{
ArgumentNullException.ThrowIfNull(discordClient);
Expand All @@ -174,11 +193,10 @@ private Task OnClientReady(DiscordClient discordClient, ReadyEventArgs eventArgs
ShardCount: discordClient.ShardCount);

_readyTaskCompletionSource.TrySetResult(clientInformation);

return Task.CompletedTask;
}

private Task OnVoiceServerUpdated(DiscordClient discordClient, VoiceServerUpdateEventArgs voiceServerUpdateEventArgs)
private async Task OnVoiceServerUpdated(DiscordClient discordClient, VoiceServerUpdateEventArgs voiceServerUpdateEventArgs)
{
ArgumentNullException.ThrowIfNull(discordClient);
ArgumentNullException.ThrowIfNull(voiceServerUpdateEventArgs);
Expand All @@ -191,17 +209,19 @@ private Task OnVoiceServerUpdated(DiscordClient discordClient, VoiceServerUpdate
guildId: voiceServerUpdateEventArgs.Guild.Id,
voiceServer: server);

return VoiceServerUpdated.InvokeAsync(this, eventArgs).AsTask();
await VoiceServerUpdated
.InvokeAsync(this, eventArgs)
.ConfigureAwait(false);
}
private Task OnVoiceStateUpdated(DiscordClient discordClient, VoiceStateUpdateEventArgs voiceStateUpdateEventArgs)

private async Task OnVoiceStateUpdated(DiscordClient discordClient, VoiceStateUpdateEventArgs voiceStateUpdateEventArgs)
{
ArgumentNullException.ThrowIfNull(discordClient);
ArgumentNullException.ThrowIfNull(voiceStateUpdateEventArgs);

// session id is the same as the resume key so DSharpPlus should be able to give us the
// session key in either before or after voice state
var sessionId = voiceStateUpdateEventArgs.Before?.GetSessionId()
?? voiceStateUpdateEventArgs.After.GetSessionId();
var sessionId = voiceStateUpdateEventArgs.Before?.GetSessionId() ?? voiceStateUpdateEventArgs.After.GetSessionId();

// create voice state
var voiceState = new VoiceState(
Expand All @@ -220,6 +240,8 @@ private Task OnVoiceStateUpdated(DiscordClient discordClient, VoiceStateUpdateEv
oldVoiceState: oldVoiceState,
voiceState: voiceState);

return VoiceStateUpdated.InvokeAsync(this, eventArgs).AsTask();
await VoiceStateUpdated
.InvokeAsync(this, eventArgs)
.ConfigureAwait(false);
}
}
19 changes: 6 additions & 13 deletions src/Lavalink4NET.DSharpPlus/Lavalink4NET.DSharpPlus.csproj
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<OutputType>Library</OutputType>
<TargetFrameworks>net6.0;net7.0</TargetFrameworks>

<TargetFrameworks>net6.0;net7.0</TargetFrameworks>
<LangVersion>latest</LangVersion>
<!-- Package Description -->
<Description>
High performance Lavalink wrapper for .NET | Add powerful audio playback to your DSharpPlus-based applications with this integration for Lavalink4NET. Suitable for end users developing with DSharpPlus.
</Description>

<Description>High performance Lavalink wrapper for .NET | Add powerful audio playback to your DSharpPlus-based applications with this integration for Lavalink4NET. Suitable for end users developing with DSharpPlus.</Description>
<PackageTags>lavalink,lavalink-wrapper,discord,discord-music,discord-music-bot,dsharpplus</PackageTags>

<!-- Documentation -->
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="DSharpPlus" Version="4.4.1" />
<ProjectReference Include="..\Lavalink4NET\Lavalink4NET.csproj" />
<PackageReference Include="DSharpPlus" Version="4.4.6" />
<ProjectReference Include="../Lavalink4NET/Lavalink4NET.csproj" />
</ItemGroup>

<Import Project="../Lavalink4NET.targets" />
</Project>
</Project>
8 changes: 8 additions & 0 deletions src/Lavalink4NET.DSharpPlus/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,16 @@
using Lavalink4NET.DSharpPlus;
using Microsoft.Extensions.DependencyInjection;

/// <summary>
/// A collection of extension methods for <see cref="IServiceCollection"/>.
/// </summary>
public static class ServiceCollectionExtensions
{
/// <summary>
/// Adds the Lavalink4NET DSharpPlus extension to the service collection.
/// </summary>
/// <param name="services">The service collection to add the extension to.</param>
/// <returns>The service collection for chaining.</returns>
public static IServiceCollection AddLavalink(this IServiceCollection services)
{
ArgumentNullException.ThrowIfNull(services);
Expand Down
Loading
Loading