Skip to content

Commit

Permalink
Merge pull request #76 from Klotzi111/main
Browse files Browse the repository at this point in the history
Fixed encryption and race conditions, added registration classes to unregister packet handlers and WaitForPacketWhere
  • Loading branch information
psu-de authored Aug 17, 2024
2 parents dffc0a8 + 85f291b commit 77d395b
Show file tree
Hide file tree
Showing 16 changed files with 347 additions and 119 deletions.
212 changes: 149 additions & 63 deletions Components/MineSharp.Protocol/MinecraftClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using MineSharp.Auth;
using MineSharp.ChatComponent;
using MineSharp.ChatComponent.Components;
using MineSharp.Core;
using MineSharp.Core.Common.Protocol;
using MineSharp.Core.Concurrency;
using MineSharp.Core.Events;
Expand All @@ -19,13 +18,12 @@
using MineSharp.Protocol.Packets;
using MineSharp.Protocol.Packets.Clientbound.Status;
using MineSharp.Protocol.Packets.Handlers;
using MineSharp.Protocol.Packets.Serverbound.Configuration;
using MineSharp.Protocol.Packets.Serverbound.Status;
using MineSharp.Protocol.Registrations;
using Newtonsoft.Json.Linq;
using NLog;

using PlayClientInformationPacket = MineSharp.Protocol.Packets.Serverbound.Play.ClientInformationPacket;
using ConfigurationClientInformationPacket = MineSharp.Protocol.Packets.Serverbound.Configuration.ClientInformationPacket;
using PlayClientInformationPacket = MineSharp.Protocol.Packets.Serverbound.Play.ClientInformationPacket;

namespace MineSharp.Protocol;

Expand Down Expand Up @@ -109,7 +107,7 @@ public sealed class MinecraftClient : IAsyncDisposable, IDisposable
private Task? streamLoop;
private int onConnectionLostFired;

private readonly ConcurrentDictionary<PacketType, ConcurrentBag<AsyncPacketHandler>> packetHandlers;
private readonly ConcurrentDictionary<PacketType, ConcurrentHashSet<AsyncPacketHandler>> packetHandlers;
private readonly ConcurrentDictionary<PacketType, TaskCompletionSource<object>> packetWaiters;
private readonly ConcurrentHashSet<AsyncPacketHandler> packetReceivers;
private GameStatePacketHandler gameStatePacketHandler;
Expand Down Expand Up @@ -161,7 +159,7 @@ public MinecraftClient(
packetHandlers = new();
packetWaiters = new();
packetReceivers = new();
GameJoinedTcs = new();
GameJoinedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
bundledPackets = null;
tcpTcpFactory = tcpFactory;
ip = IpHelper.ResolveHostname(hostnameOrIp, ref port);
Expand Down Expand Up @@ -266,7 +264,7 @@ public async Task<bool> Connect(GameState nextState)
/// <returns>A task that resolves once the packet was actually sent.</returns>
public async Task SendPacket(IPacket packet, CancellationToken cancellation = default)
{
var sendingTask = new PacketSendTask(packet, cancellation, new());
var sendingTask = new PacketSendTask(packet, cancellation, new(TaskCreationOptions.RunContinuationsAsynchronously));
try
{
if (!await packetQueue.SendAsync(sendingTask, cancellation))
Expand Down Expand Up @@ -326,18 +324,107 @@ private async Task DisconnectInternal(Chat? reason = null)
await OnConnectionLost.Dispatch(this);
}

/// <summary>
/// Represents a registration for a packet handler that will be called whenever a packet of type <typeparamref name="T" /> is received.
/// This registration can be used to unregister the handler.
/// </summary>
/// <typeparam name="T">The type of the packet.</typeparam>
public sealed class OnPacketRegistration<T> : AbstractPacketReceiveRegistration
where T : IPacket
{
internal OnPacketRegistration(MinecraftClient client, AsyncPacketHandler handler)
: base(client, handler)
{
}

/// <inheritdoc/>
protected override void Unregister()
{
var key = T.StaticType;
if (Client.packetHandlers.TryGetValue(key, out var handlers))
{
handlers.TryRemove(Handler);
}
}
}

/// <summary>
/// Registers an <paramref name="handler" /> that will be called whenever an packet of type <typeparamref name="T" />
/// is received
/// </summary>
/// <param name="handler">A delegate that will be called when a packet of type T is received</param>
/// <typeparam name="T">The type of the packet</typeparam>
public void On<T>(AsyncPacketHandler<T> handler) where T : IPacket
/// <returns>A registration object that can be used to unregister the handler.</returns>
public OnPacketRegistration<T>? On<T>(AsyncPacketHandler<T> handler) where T : IPacket
{
var key = T.StaticType;
AsyncPacketHandler rawHandler = packet => handler((T)packet);
var added = packetHandlers.GetOrAdd(key, _ => new ConcurrentHashSet<AsyncPacketHandler>())
.Add(rawHandler);
return added ? new(this, rawHandler) : null;
}

packetHandlers.GetOrAdd(key, _ => new ConcurrentBag<AsyncPacketHandler>())
.Add(p => handler((T)p));
/// <summary>
/// Waits until a packet of the specified type is received and matches the given condition.
/// </summary>
/// <typeparam name="T">The type of the packet.</typeparam>
/// <param name="condition">A function that evaluates the packet and returns true if the condition is met.</param>
/// <param name="cancellationToken">A token to cancel the wait for the matching packet.</param>
/// <returns>A task that completes once a packet matching the condition is received.</returns>
public Task WaitForPacketWhere<T>(Func<T, Task<bool>> condition, CancellationToken cancellationToken = default)
where T : IPacket
{
// linked token is required to cancel the task when the client is disconnected
var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, CancellationToken);
var token = cts.Token;
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
async Task PacketHandler(T packet)
{
try
{
if (tcs.Task.IsCompleted)
{
return;
}
if (await condition(packet).WaitAsync(token))
{
tcs.TrySetResult();
}
}
catch (OperationCanceledException e)
{
tcs.TrySetCanceled(e.CancellationToken);
}
catch (Exception e)
{
tcs.TrySetException(e);
}
}
var packetRegistration = On<T>(PacketHandler);
if (packetRegistration == null)
{
// TODO: Can this occur?
cts.Dispose();
throw new InvalidOperationException("Could not register packet handler");
}
return tcs.Task.ContinueWith(_ =>
{
packetRegistration.Dispose();
cts.Dispose();
}, TaskContinuationOptions.ExecuteSynchronously);
}

/// <summary>
/// Waits until a packet of the specified type is received and matches the given condition.
/// </summary>
/// <typeparam name="T">The type of the packet.</typeparam>
/// <param name="condition">A function that evaluates the packet and returns true if the condition is met.</param>
/// <param name="cancellationToken">A token to cancel the wait for the matching packet.</param>
/// <returns>A task that completes once a packet matching the condition is received.</returns>
public Task WaitForPacketWhere<T>(Func<T, bool> condition, CancellationToken cancellationToken = default)
where T : IPacket
{
return WaitForPacketWhere<T>(packet => Task.FromResult(condition(packet)), cancellationToken);
}

/// <summary>
Expand All @@ -348,30 +435,25 @@ public void On<T>(AsyncPacketHandler<T> handler) where T : IPacket
public Task<T> WaitForPacket<T>() where T : IPacket
{
var packetType = T.StaticType;
var tcs = packetWaiters.GetOrAdd(packetType, _ => new TaskCompletionSource<object>());
var tcs = packetWaiters.GetOrAdd(packetType, _ => new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously));
return tcs.Task.ContinueWith(prev => (T)prev.Result);
}

public sealed class OnPacketReceivedRegistration : IDisposable
/// <summary>
/// Represents a registration for a packet handler that will be called whenever any packet is received.
/// This registration can be used to unregister the handler.
/// </summary>
public sealed class OnPacketReceivedRegistration : AbstractPacketReceiveRegistration
{
private readonly MinecraftClient client;
private readonly AsyncPacketHandler handler;
public bool Disposed { get; private set; }

internal OnPacketReceivedRegistration(MinecraftClient client, AsyncPacketHandler handler)
: base(client, handler)
{
this.client = client;
this.handler = handler;
}

public void Dispose()
/// <inheritdoc/>
protected override void Unregister()
{
if (Disposed)
{
return;
}
client.packetReceivers.TryRemove(handler);
Disposed = true;
Client.packetReceivers.TryRemove(Handler);
}
}

Expand All @@ -382,6 +464,7 @@ public void Dispose()
/// Use this for debugging purposes only.
/// </summary>
/// <param name="handler">A delegate that will be called when a packet is received.</param>
/// <returns>A registration object that can be used to unregister the handler.</returns>
public OnPacketReceivedRegistration? OnPacketReceived(AsyncPacketHandler handler)
{
var added = packetReceivers.Add(handler);
Expand All @@ -397,29 +480,30 @@ public Task WaitForGame()
return GameJoinedTcs.Task;
}

internal Task SendClientInformationPacket()
internal Task SendClientInformationPacket(GameState gameState)
{
IPacket packet = Data.Version.Protocol >= ProtocolVersion.V_1_20_3
? new ConfigurationClientInformationPacket(
IPacket packet = gameState switch
{
GameState.Configuration => new ConfigurationClientInformationPacket(
Settings.Locale,
Settings.ViewDistance,
Settings.ChatMode,
Settings.ColoredChat,
Settings.DisplayedSkinParts,
Settings.MainHand,
Settings.EnableTextFiltering,
Settings.AllowServerListings)
: new PlayClientInformationPacket(
Settings.AllowServerListings),
GameState.Play => new PlayClientInformationPacket(
Settings.Locale,
Settings.ViewDistance,
Settings.ChatMode,
Settings.ColoredChat,
Settings.DisplayedSkinParts,
Settings.MainHand,
Settings.EnableTextFiltering,
Settings.AllowServerListings
);

Settings.AllowServerListings),
_ => throw new NotImplementedException(),
};
return SendPacket(packet);
}

Expand Down Expand Up @@ -551,21 +635,24 @@ private async Task SendPackets()
try
{
DispatchPacket(task.Packet);
// TrySetResult must be run from a different task to prevent blocking the stream loop
// because the task continuation will be executed inline and might block or cause a deadlock
_ = Task.Run(task.Task.TrySetResult);
task.Task.TrySetResult();
}
catch (SocketException e)
catch (OperationCanceledException e)
{
Logger.Error(e, "Encountered exception while dispatching packet {PacketType}", task.Packet.Type);
task.Task.TrySetException(e);
// break the loop to prevent further packets from being sent
// because the connection is probably dead
task.Task.TrySetCanceled(e.CancellationToken);
// we should stop. So we do by rethrowing the exception
throw;
}
catch (Exception e)
{
Logger.Error(e, "Encountered exception while dispatching packet {PacketType}", task.Packet.Type);
task.Task.TrySetException(e);
if (e is SocketException)
{
// break the loop to prevent further packets from being sent
// because the connection is probably dead
throw;
}
}
}
}
Expand Down Expand Up @@ -656,7 +743,7 @@ private async Task HandleIncomingPacket(PacketType packetType, PacketBuffer buff
return;
}

var packet = (IPacket?) await ParsePacket(factory, packetType, buffer);
var packet = (IPacket?)await ParsePacket(factory, packetType, buffer);

if (packet == null)
{
Expand Down Expand Up @@ -739,35 +826,34 @@ public static async Task<ServerStatus> RequestServerStatus(
throw new MineSharpHostException("failed to connect to server");
}

var responseTimeoutCts = new CancellationTokenSource();
using var responseTimeoutCts = new CancellationTokenSource();
var responseTimeoutCancellationToken = responseTimeoutCts.Token;
var taskCompletionSource = new TaskCompletionSource<ServerStatus>();

client.On<StatusResponsePacket>(async packet =>
{
var json = packet.Response;
var response = ServerStatus.FromJToken(JToken.Parse(json), client.Data);
taskCompletionSource.TrySetResult(response);

// the server closes the connection
// after sending the StatusResponsePacket
// so just dispose the client (no point in disconnecting)
await client.DisposeAsync();
});
var statusResponsePacketTask = client.WaitForPacket<StatusResponsePacket>();

await client.SendPacket(new StatusRequestPacket(), responseTimeoutCancellationToken);
await client.SendPacket(new PingRequestPacket(DateTimeOffset.UtcNow.ToUnixTimeMilliseconds()), responseTimeoutCancellationToken);

responseTimeoutCancellationToken.Register(
() =>
{
taskCompletionSource.TrySetCanceled(responseTimeoutCancellationToken);
responseTimeoutCts.Dispose();
});

responseTimeoutCts.CancelAfter(responseTimeout);

return await taskCompletionSource.Task;
var statusResponsePacket = await statusResponsePacketTask.WaitAsync(responseTimeoutCancellationToken);
var json = statusResponsePacket.Response;
var response = ServerStatus.FromJToken(JToken.Parse(json), client.Data);

// the server closes the connection
// after sending the StatusResponsePacket and PingResponsePacket
// so just dispose the client (no point in disconnecting)
try
{
await client.DisposeAsync();
}
catch (Exception)
{
// ignore all errors
// in most cases the exception is an OperationCanceledException because the connection was terminated
}

return response;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public ConfigurationPacketHandler(MinecraftClient client, MinecraftData data)

public override Task StateEntered()
{
return client.SendClientInformationPacket();
return client.SendClientInformationPacket(GameState);
}

public override Task HandleIncoming(IPacket packet)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ private async Task HandleEncryptionRequest(EncryptionRequestPacket packet)
response = new(sharedSecret, encVerToken, null);
}

_ = client.SendPacket(response)
.ContinueWith(_ => client.EnableEncryption(aes.Key));
await client.SendPacket(response);
client.EnableEncryption(aes.Key);
}

private Task HandleSetCompression(SetCompressionPacket packet)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public override async Task StateEntered()
{
if (data.Version.Protocol <= ProtocolVersion.V_1_20)
{
await client.SendClientInformationPacket();
await client.SendClientInformationPacket(GameState);
}
client.GameJoinedTcs.SetResult();
}
Expand Down
Loading

0 comments on commit 77d395b

Please sign in to comment.