Skip to content

Commit

Permalink
Move Cancellation into new ConsumerAgent
Browse files Browse the repository at this point in the history
  • Loading branch information
NooNameR authored and phatboyg committed Feb 26, 2023
1 parent 3913ba8 commit f01d817
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 67 deletions.
47 changes: 40 additions & 7 deletions src/MassTransit/Transports/ConsumerAgent.cs
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
namespace MassTransit.Transports
{
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Internals;
using Middleware;
using Util;


public abstract class ConsumerAgent :
public abstract class ConsumerAgent<TKey> :
Agent,
DeliveryMetrics
{
readonly ReceiveEndpointContext _context;
readonly TaskCompletionSource<bool> _deliveryComplete;
readonly IReceivePipeDispatcher _dispatcher;
readonly object _lock = new object();
readonly ConcurrentDictionary<TKey, BaseReceiveContext> _pending;
Task _consumeTask;
TaskCompletionSource<bool> _consumeTaskSource;

protected ConsumerAgent(ReceiveEndpointContext context)
protected ConsumerAgent(ReceiveEndpointContext context, IEqualityComparer<TKey> equalityComparer = default)
{
_context = context;
_deliveryComplete = TaskUtil.GetTask<bool>();

_pending = new ConcurrentDictionary<TKey, BaseReceiveContext>(equalityComparer);

_dispatcher = context.CreateReceivePipeDispatcher();
_dispatcher.ZeroActivity += HandleDeliveryComplete;
}
Expand Down Expand Up @@ -118,6 +123,15 @@ protected override Task StopAgent(StopContext context)
return Completed;
}

void CancelPendingConsumers()
{
foreach (var key in _pending.Keys)
{
if (_pending.TryRemove(key, out var context))
context.Cancel();
}
}

protected void TrySetConsumeCompleted()
{
_consumeTaskSource?.TrySetResult(true);
Expand All @@ -128,8 +142,7 @@ protected void TrySetConsumeCanceled(CancellationToken cancellationToken = defau
if (_consumeTaskSource == null)
return;

if (IsIdle)
_deliveryComplete.TrySetResult(false);
CancelPendingConsumers();

_consumeTaskSource.TrySetCanceled(cancellationToken);
}
Expand All @@ -139,7 +152,7 @@ protected void TrySetConsumeException(Exception exception)
if (_consumeTaskSource == null)
return;

_deliveryComplete.TrySetResult(false);
CancelPendingConsumers();

_consumeTaskSource.TrySetException(exception);
}
Expand All @@ -155,6 +168,8 @@ protected virtual async Task ActiveAndActualAgentsCompleted(StopContext context)
catch (OperationCanceledException)
{
LogContext.Warning?.Log("Stop canceled waiting for message consumers to complete: {InputAddress}", _context.InputAddress);

CancelPendingConsumers();
}
}

Expand All @@ -174,9 +189,27 @@ protected virtual async Task ActiveAndActualAgentsCompleted(StopContext context)
}
}

protected Task Dispatch(ReceiveContext context, ReceiveLockContext receiveLock = default)
protected virtual bool IsDuplicate(TKey key)
{
return true;
}

protected async Task Dispatch<TContext>(TKey key, TContext context, ReceiveLockContext receiveLock = default)
where TContext : BaseReceiveContext
{
return _dispatcher.Dispatch(context, receiveLock);
var added = _pending.TryAdd(key, context);
if (!added && IsDuplicate(key))
LogContext.Warning?.Log("Duplicate dispatch key {Key}", key);

try
{
await _dispatcher.Dispatch(context, receiveLock).ConfigureAwait(false);
}
finally
{
if (added)
_pending.TryRemove(key, out _);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace MassTransit.ActiveMqTransport.Middleware
/// Receives messages from ActiveMQ, pushing them to the InboundPipe of the service endpoint.
/// </summary>
public sealed class ActiveMqConsumer :
ConsumerAgent
ConsumerAgent<string>
{
readonly ActiveMqReceiveEndpointContext _context;
readonly ChannelExecutor _executor;
Expand All @@ -28,7 +28,7 @@ public sealed class ActiveMqConsumer :
/// <param name="context">The topology</param>
/// <param name="executor"></param>
public ActiveMqConsumer(SessionContext session, MessageConsumer messageConsumer, ActiveMqReceiveEndpointContext context, ChannelExecutor executor)
: base(context)
: base(context, StringComparer.Ordinal)
{
_session = session;
_messageConsumer = messageConsumer;
Expand All @@ -54,7 +54,7 @@ void HandleMessage(IMessage message)

try
{
await Dispatch(context, context).ConfigureAwait(false);
await Dispatch(message.NMSMessageId, context, context).ConfigureAwait(false);
}
catch (Exception exception)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace MassTransit.AmazonSqsTransport.Middleware
/// Receives messages from AmazonSQS, pushing them to the InboundPipe of the service endpoint.
/// </summary>
public sealed class AmazonSqsMessageReceiver :
ConsumerAgent
ConsumerAgent<string>
{
readonly ClientContext _client;
readonly SqsReceiveEndpointContext _context;
Expand All @@ -28,7 +28,7 @@ public sealed class AmazonSqsMessageReceiver :
/// <param name="client">The model context for the consumer</param>
/// <param name="context">The topology</param>
public AmazonSqsMessageReceiver(ClientContext client, SqsReceiveEndpointContext context)
: base(context)
: base(context, StringComparer.Ordinal)
{
_client = client;
_context = context;
Expand Down Expand Up @@ -99,7 +99,7 @@ async Task HandleMessage(Message message)
var context = new AmazonSqsReceiveContext(message, redelivered, _context, _client, _receiveSettings, _client.ConnectionContext);
try
{
await Dispatch(context, context).ConfigureAwait(false);
await Dispatch(message.MessageId, context, context).ConfigureAwait(false);
}
catch (Exception exception)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


public class Receiver :
ConsumerAgent,
ConsumerAgent<long>,
IReceiver
{
readonly ClientContext _clientContext;
Expand Down Expand Up @@ -136,7 +136,7 @@ protected async Task Dispatch(ServiceBusReceivedMessage message, ServiceBusRecei
try
{
var receiveLock = new ServiceBusReceiveLockContext(lockContext, context);
await Dispatch(context, receiveLock).ConfigureAwait(false);
await Dispatch(context.SequenceNumber, context, receiveLock).ConfigureAwait(false);
}
catch (ServiceBusException ex) when (ex.Reason == ServiceBusFailureReason.SessionLockLost)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
using System.Threading.Tasks;
using Azure.Messaging.EventHubs;
using Azure.Messaging.EventHubs.Processor;
using Checkpoints;
using MassTransit.Middleware;
using Transports;
using Util;


public class EventHubDataReceiver :
ConsumerAgent,
ConsumerAgent<PartitionOffset>,
IEventHubDataReceiver
{
readonly CancellationTokenSource _checkpointTokenSource;
Expand Down Expand Up @@ -75,7 +76,11 @@ async Task Handle(ProcessEventArgs eventArgs)

try
{
await Dispatch(context, context).ConfigureAwait(false);
await Dispatch(eventArgs, context, context).ConfigureAwait(false);
}
catch (Exception exception)
{
context.LogTransportFaulted(exception);
}
finally
{
Expand All @@ -99,6 +104,7 @@ protected override async Task ActiveAndActualAgentsCompleted(StopContext context
_client.ProcessEventAsync -= HandleMessage;
_client.ProcessErrorAsync -= HandleError;

await _lockContext.DisposeAsync().ConfigureAwait(false);
_checkpointTokenSource.Dispose();
}

Expand Down Expand Up @@ -128,10 +134,9 @@ public Task Run(ProcessEventArgs args, Func<Task> method, CancellationToken canc
return _partitionExecutorPool.Run(args, () => _keyExecutorPool.Run(args, method, cancellationToken), cancellationToken);
}

public async ValueTask DisposeAsync()
public ValueTask DisposeAsync()
{
await _partitionExecutorPool.DisposeAsync().ConfigureAwait(false);
await _keyExecutorPool.DisposeAsync().ConfigureAwait(false);
return _keyExecutorPool.DisposeAsync();
}

static byte[] GetBytes(ProcessEventArgs args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ namespace MassTransit.EventHubIntegration
using Azure.Messaging.EventHubs.Processor;


public interface IProcessorLockContext
public interface IProcessorLockContext :
IAsyncDisposable
{
Task Pending(ProcessEventArgs eventArgs);
Task Complete(ProcessEventArgs eventArgs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ public ProcessorLockContext(ProcessorContext context, ReceiveSettings receiveSet
_data = new SingleThreadedDictionary<string, PartitionCheckpointData>(StringComparer.Ordinal);
}

public ValueTask DisposeAsync()
{
return default;
}

public Task Push(ProcessEventArgs partition, Func<Task> method, CancellationToken cancellationToken = default)
{
return _data.TryGetValue(partition.Partition.PartitionId, out var data) ? data.Push(method) : Task.CompletedTask;
Expand All @@ -41,6 +36,12 @@ public Task Run(ProcessEventArgs partition, Func<Task> method, CancellationToken
return _data.TryGetValue(partition.Partition.PartitionId, out var data) ? data.Run(method) : Task.CompletedTask;
}

public ValueTask DisposeAsync()
{
_pending.Dispose();
return default;
}

public Task Pending(ProcessEventArgs eventArgs)
{
LogContext.SetCurrentIfNull(_context.LogContext);
Expand Down Expand Up @@ -122,7 +123,6 @@ public async Task Close(PartitionClosingEventArgs args)

LogContext.Info?.Log("Partition: {PartitionId} was closed, reason: {Reason}", args.PartitionId, args.Reason);

_pending.Dispose();
_cancellationTokenSource.Cancel();
_cancellationTokenSource.Dispose();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ public ConsumerLockContext(ConsumerContext context, ReceiveSettings receiveSetti
_data = new SingleThreadedDictionary<TopicPartition, PartitionCheckpointData>();
}

public ValueTask DisposeAsync()
{
return default;
}

public Task Push(ConsumeResult<byte[], byte[]> partition, Func<Task> method, CancellationToken cancellationToken = default)
{
return _data.TryGetValue(partition.TopicPartition, out var data) ? data.Push(method) : Task.CompletedTask;
Expand All @@ -43,6 +38,12 @@ public Task Run(ConsumeResult<byte[], byte[]> partition, Func<Task> method, Canc
return _data.TryGetValue(partition.TopicPartition, out var data) ? data.Run(method) : Task.CompletedTask;
}

public ValueTask DisposeAsync()
{
_pending.Dispose();
return default;
}

public Task Pending(ConsumeResult<byte[], byte[]> result)
{
LogContext.SetCurrentIfNull(_context.LogContext);
Expand Down Expand Up @@ -154,7 +155,6 @@ public async Task Close()
await _executor.DisposeAsync().ConfigureAwait(false);
await _checkpointer.DisposeAsync().ConfigureAwait(false);

_pending.Dispose();
_cancellationTokenSource.Cancel();
_cancellationTokenSource.Dispose();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ namespace MassTransit.KafkaIntegration
using Confluent.Kafka;


public interface IConsumerLockContext
public interface IConsumerLockContext :
IAsyncDisposable
{
Task Pending(ConsumeResult<byte[], byte[]> result);
Task Complete(ConsumeResult<byte[], byte[]> result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


public class KafkaMessageConsumer<TKey, TValue> :
ConsumerAgent,
ConsumerAgent<TopicPartitionOffset>,
IKafkaMessageConsumer<TKey, TValue>
where TValue : class
{
Expand Down Expand Up @@ -80,7 +80,7 @@ async Task Handle(ConsumeResult<byte[], byte[]> result)

try
{
await Dispatch(context, context).ConfigureAwait(false);
await Dispatch(result.TopicPartitionOffset, context, context).ConfigureAwait(false);
}
catch (Exception exception)
{
Expand Down Expand Up @@ -115,9 +115,10 @@ protected override async Task ActiveAndActualAgentsCompleted(StopContext context
_checkpointTokenSource.Cancel();

_consumer.Close();

_consumer.Dispose();
_cancellationTokenSource.Dispose();

await _lockContext.DisposeAsync().ConfigureAwait(false);
_checkpointTokenSource.Dispose();
}

Expand Down Expand Up @@ -147,10 +148,9 @@ public Task Run(ConsumeResult<byte[], byte[]> result, Func<Task> method, Cancell
return _partitionExecutorPool.Run(result, () => _keyExecutorPool.Run(result, method, cancellationToken), cancellationToken);
}

public async ValueTask DisposeAsync()
public ValueTask DisposeAsync()
{
await _partitionExecutorPool.DisposeAsync().ConfigureAwait(false);
await _keyExecutorPool.DisposeAsync().ConfigureAwait(false);
return _keyExecutorPool.DisposeAsync();
}
}
}
Expand Down
Loading

0 comments on commit f01d817

Please sign in to comment.