Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/Polly.Core/CircuitBreaker/CircuitBreakerResilienceStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ public CircuitBreakerResilienceStrategy(

stateProvider?.Initialize(() => _controller.CircuitState);
_manualControlRegistration = manualControl?.Initialize(
async c => await _controller.IsolateCircuitAsync(c).ConfigureAwait(c.ContinueOnCapturedContext),
async c => await _controller.CloseCircuitAsync(c).ConfigureAwait(c.ContinueOnCapturedContext));
_controller.IsolateCircuitAsync,
_controller.CloseCircuitAsync);
}

public void Dispose()
Expand All @@ -34,7 +34,17 @@ protected internal override async ValueTask<Outcome<T>> ExecuteCore<TState>(Func
return outcome;
}

outcome = await StrategyHelper.ExecuteCallbackSafeAsync(callback, context, state).ConfigureAwait(context.ContinueOnCapturedContext);
try
{
context.CancellationToken.ThrowIfCancellationRequested();
outcome = await callback(context, state).ConfigureAwait(context.ContinueOnCapturedContext);
}
#pragma warning disable CA1031
catch (Exception ex)
{
outcome = new(ex);
}
#pragma warning restore CA1031

var args = new CircuitBreakerPredicateArguments<T>(context, outcome);
if (await _handler(args).ConfigureAwait(context.ContinueOnCapturedContext))
Expand Down
74 changes: 42 additions & 32 deletions src/Polly.Core/CircuitBreaker/Controller/CircuitStateController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public Outcome<T>? LastHandledOutcome
}
}

public ValueTask IsolateCircuitAsync(ResilienceContext context)
public Task IsolateCircuitAsync(ResilienceContext context)
{
EnsureNotDisposed();

Expand All @@ -98,14 +98,14 @@ public ValueTask IsolateCircuitAsync(ResilienceContext context)
var exception = new IsolatedCircuitException();
_telemetry.SetTelemetrySource(exception);
SetLastHandledOutcome_NeedsLock(Outcome.FromException<T>(exception));
OpenCircuitFor_NeedsLock(Outcome.FromResult<T>(default), TimeSpan.MaxValue, manual: true, context, out task);
task = OpenCircuitFor_NeedsLock(Outcome.FromResult<T>(default), TimeSpan.MaxValue, manual: true, context);
_circuitState = CircuitState.Isolated;
}

return ExecuteScheduledTaskAsync(task, context);
}

public ValueTask CloseCircuitAsync(ResilienceContext context)
public Task CloseCircuitAsync(ResilienceContext context)
{
EnsureNotDisposed();

Expand All @@ -115,7 +115,7 @@ public ValueTask CloseCircuitAsync(ResilienceContext context)

lock (_lock)
{
CloseCircuit_NeedsLock(Outcome.FromResult<T>(default), manual: true, context, out task);
task = CloseCircuit_NeedsLock(Outcome.FromResult<T>(default), manual: true, context);
}

return ExecuteScheduledTaskAsync(task, context);
Expand Down Expand Up @@ -166,7 +166,7 @@ public ValueTask CloseCircuitAsync(ResilienceContext context)
return null;
}

public ValueTask OnUnhandledOutcomeAsync(Outcome<T> outcome, ResilienceContext context)
public Task OnUnhandledOutcomeAsync(Outcome<T> outcome, ResilienceContext context)
{
EnsureNotDisposed();

Expand All @@ -184,15 +184,15 @@ public ValueTask OnUnhandledOutcomeAsync(Outcome<T> outcome, ResilienceContext c
// We take no special action; only time passing governs transitioning from Open to HalfOpen state.
if (_circuitState == CircuitState.HalfOpen)
{
CloseCircuit_NeedsLock(outcome, manual: false, context, out task);
task = CloseCircuit_NeedsLock(outcome, manual: false, context);
}

}

return ExecuteScheduledTaskAsync(task, context);
}

public ValueTask OnHandledOutcomeAsync(Outcome<T> outcome, ResilienceContext context)
public Task OnHandledOutcomeAsync(Outcome<T> outcome, ResilienceContext context)
{
EnsureNotDisposed();

Expand All @@ -214,7 +214,7 @@ public ValueTask OnHandledOutcomeAsync(Outcome<T> outcome, ResilienceContext con

if (_circuitState == CircuitState.HalfOpen || (_circuitState == CircuitState.Closed && shouldBreak))
{
OpenCircuit_NeedsLock(outcome, manual: false, context, out task);
task = OpenCircuitFor_NeedsLock(outcome, _breakDuration, manual: false, context);
}
}

Expand All @@ -227,22 +227,35 @@ public void Dispose()
_disposed = true;
}

internal static async ValueTask ExecuteScheduledTaskAsync(Task? task, ResilienceContext context)
internal static Task ExecuteScheduledTaskAsync(Task? task, ResilienceContext context)
{
if (task is not null)
{
if (context.IsSynchronous)
if (context.IsSynchronous && !task.IsCompleted)
{
#pragma warning disable CA1849 // Call async methods when in an async method
// because this is synchronous execution we need to block
task.GetAwaiter().GetResult();
#if NET8_0_OR_GREATER
task.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing).GetAwaiter().GetResult();
#else
try
{
task.GetAwaiter().GetResult();
}
#pragma warning disable CA1031
catch
{
// exception will be observed by the awaiter of this method
}
#pragma warning restore CA1031
#endif
#pragma warning restore CA1849 // Call async methods when in an async method
}
else
{
await task.ConfigureAwait(context.ContinueOnCapturedContext);
}

return task;
}

return Task.CompletedTask;
}

private static bool IsDateTimeOverflow(DateTimeOffset utcNow, TimeSpan breakDuration)
Expand All @@ -266,10 +279,8 @@ private void EnsureNotDisposed()
}
#endif

private void CloseCircuit_NeedsLock(Outcome<T> outcome, bool manual, ResilienceContext context, out Task? scheduledTask)
private Task? CloseCircuit_NeedsLock(Outcome<T> outcome, bool manual, ResilienceContext context)
{
scheduledTask = null;

_blockedUntil = DateTimeOffset.MinValue;
_lastOutcome = null;
_halfOpenAttempts = 0;
Expand All @@ -285,9 +296,13 @@ private void CloseCircuit_NeedsLock(Outcome<T> outcome, bool manual, ResilienceC

if (_onClosed is not null)
{
_executor.ScheduleTask(() => _onClosed(args).AsTask(), context, out scheduledTask);
return _executor.ScheduleTask(() => _onClosed(args).AsTask());
}
}

#pragma warning disable S4586
return null;
#pragma warning restore S4586
}

private bool PermitHalfOpenCircuitTest_NeedsLock()
Expand All @@ -311,21 +326,13 @@ private void SetLastHandledOutcome_NeedsLock(Outcome<T> outcome)
private BrokenCircuitException CreateBrokenCircuitException()
{
TimeSpan retryAfter = _blockedUntil - _timeProvider.GetUtcNow();
var exception = _breakingException switch
{
Exception ex => new BrokenCircuitException(BrokenCircuitException.DefaultMessage, retryAfter, ex),
_ => new BrokenCircuitException(BrokenCircuitException.DefaultMessage, retryAfter)
};
var exception = new BrokenCircuitException(BrokenCircuitException.DefaultMessage, retryAfter, _breakingException!);
_telemetry.SetTelemetrySource(exception);
return exception;
}

private void OpenCircuit_NeedsLock(Outcome<T> outcome, bool manual, ResilienceContext context, out Task? scheduledTask)
=> OpenCircuitFor_NeedsLock(outcome, _breakDuration, manual, context, out scheduledTask);

private void OpenCircuitFor_NeedsLock(Outcome<T> outcome, TimeSpan breakDuration, bool manual, ResilienceContext context, out Task? scheduledTask)
private Task? OpenCircuitFor_NeedsLock(Outcome<T> outcome, TimeSpan breakDuration, bool manual, ResilienceContext context)
{
scheduledTask = null;
var utcNow = _timeProvider.GetUtcNow();

if (_breakDurationGenerator is not null)
Expand All @@ -345,14 +352,17 @@ private void OpenCircuitFor_NeedsLock(Outcome<T> outcome, TimeSpan breakDuration

if (_onOpened is not null)
{
_executor.ScheduleTask(() => _onOpened(args).AsTask(), context, out scheduledTask);
return _executor.ScheduleTask(() => _onOpened(args).AsTask());
}

#pragma warning disable S4586
return null;
#pragma warning restore S4586
}

private Task ScheduleHalfOpenTask(ResilienceContext context)
{
_executor.ScheduleTask(() => _onHalfOpen!(new OnCircuitHalfOpenedArguments(context)).AsTask(), context, out var task);
return task;
return _executor.ScheduleTask(() => _onHalfOpen!(new OnCircuitHalfOpenedArguments(context)).AsTask());
}
}

27 changes: 10 additions & 17 deletions src/Polly.Core/CircuitBreaker/Controller/ScheduledTaskExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal sealed class ScheduledTaskExecutor : IDisposable

public Task ProcessingTask { get; }

public void ScheduleTask(Func<Task> taskFactory, ResilienceContext context, out Task task)
public Task ScheduleTask(Func<Task> taskFactory)
{
#if NET8_0_OR_GREATER
ObjectDisposedException.ThrowIf(_disposed, this);
Expand All @@ -27,10 +27,10 @@ public void ScheduleTask(Func<Task> taskFactory, ResilienceContext context, out
#endif

var source = new TaskCompletionSource<object>();
task = source.Task;

_tasks.Enqueue(new Entry(taskFactory, context.ContinueOnCapturedContext, source));
_tasks.Enqueue(new Entry(taskFactory, source));
_semaphore.Release();
return source.Task;
}

public void Dispose()
Expand All @@ -53,36 +53,29 @@ public void Dispose()

private async Task StartProcessingAsync()
{
while (true)
while (!_disposed)
{
await _semaphore.WaitAsync().ConfigureAwait(false);
if (_disposed)
if (_disposed || !_tasks.TryDequeue(out var entry))
{
return;
}

_ = _tasks.TryDequeue(out var entry);

try
{
await entry!.TaskFactory().ConfigureAwait(entry.ContinueOnCapturedContext);
entry.TaskCompletion.SetResult(null!);
await entry.TaskFactory().ConfigureAwait(false);
entry.TaskCompletion.TrySetResult(null!);
}
catch (OperationCanceledException)
{
entry!.TaskCompletion.SetCanceled();
entry.TaskCompletion.TrySetCanceled();
}
catch (Exception e)
{
entry!.TaskCompletion.SetException(e);
}

if (_disposed)
{
return;
entry.TaskCompletion.TrySetException(e);
}
}
}

private sealed record Entry(Func<Task> TaskFactory, bool ContinueOnCapturedContext, TaskCompletionSource<object> TaskCompletion);
private sealed record Entry(Func<Task> TaskFactory, TaskCompletionSource<object> TaskCompletion);
}
6 changes: 1 addition & 5 deletions src/Polly.Core/Fallback/FallbackHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,4 @@ namespace Polly.Fallback;

internal sealed record class FallbackHandler<T>(
Func<FallbackPredicateArguments<T>, ValueTask<bool>> ShouldHandle,
Func<FallbackActionArguments<T>, ValueTask<Outcome<T>>> ActionGenerator)
{
public ValueTask<Outcome<T>> GetFallbackOutcomeAsync(FallbackActionArguments<T> args) => ActionGenerator(args);
}

Func<FallbackActionArguments<T>, ValueTask<Outcome<T>>> ActionGenerator);
13 changes: 11 additions & 2 deletions src/Polly.Core/Fallback/FallbackResilienceStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@ public FallbackResilienceStrategy(FallbackHandler<T> handler, Func<OnFallbackArg

protected internal override async ValueTask<Outcome<T>> ExecuteCore<TState>(Func<ResilienceContext, TState, ValueTask<Outcome<T>>> callback, ResilienceContext context, TState state)
{
var outcome = await StrategyHelper.ExecuteCallbackSafeAsync(callback, context, state).ConfigureAwait(context.ContinueOnCapturedContext);
Outcome<T> outcome;
try
{
outcome = await callback(context, state).ConfigureAwait(context.ContinueOnCapturedContext);
}
catch (Exception ex)
{
outcome = new(ex);
}

var handleFallbackArgs = new FallbackPredicateArguments<T>(context, outcome);
if (!await _handler.ShouldHandle(handleFallbackArgs).ConfigureAwait(context.ContinueOnCapturedContext))
{
Expand All @@ -37,7 +46,7 @@ protected internal override async ValueTask<Outcome<T>> ExecuteCore<TState>(Func

try
{
return await _handler.GetFallbackOutcomeAsync(new FallbackActionArguments<T>(context, outcome)).ConfigureAwait(context.ContinueOnCapturedContext);
return await _handler.ActionGenerator(new FallbackActionArguments<T>(context, outcome)).ConfigureAwait(context.ContinueOnCapturedContext);
}
catch (Exception e)
{
Expand Down
6 changes: 3 additions & 3 deletions src/Polly.Core/Hedging/Controller/HedgingController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ public HedgingController(
return true;
});

Action<HedgingExecutionContext<T>> onReset = null!;
_contextPool = new ObjectPool<HedgingExecutionContext<T>>(
() =>
{
Interlocked.Increment(ref _rentedContexts);
return new HedgingExecutionContext<T>(_executionPool, provider, maxAttempts, ReturnContext);
return new HedgingExecutionContext<T>(_executionPool, provider, maxAttempts, onReset);
},
_ =>
{
Expand All @@ -45,6 +46,7 @@ public HedgingController(
// Stryker disable once Boolean : no means to test this
return true;
});
onReset = _contextPool.Return;
}

public int RentedContexts => _rentedContexts;
Expand All @@ -57,6 +59,4 @@ public HedgingExecutionContext<T> GetContext(ResilienceContext context)
executionContext.Initialize(context);
return executionContext;
}

private void ReturnContext(HedgingExecutionContext<T> context) => _contextPool.Return(context);
}
Loading
Loading