Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Npgsql/BackendMessages/ErrorOrNoticeMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ internal ErrorOrNoticeMessage(
/// <summary>
/// Error and notice message field codes
/// </summary>
enum ErrorFieldTypeCode : byte
internal enum ErrorFieldTypeCode : byte
{
Done = 0,
Severity = (byte)'S',
Expand Down
2 changes: 1 addition & 1 deletion src/Npgsql/NpgsqlCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,7 @@ internal async ValueTask<NpgsqlDataReader> ExecuteReader(CommandBehavior behavio
if (conn.TryGetBoundConnector(out var connector))
{
connector.StartUserAction(this);
connector.ResetCancellation();
connector.ResetCancellation(cancellationToken);

CancellationTokenRegistration? registration = null;

Expand Down
25 changes: 18 additions & 7 deletions src/Npgsql/NpgsqlConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ internal void FlagAsWritableForMultiplexing()
internal int ClearCounter { get; set; }

volatile bool _cancellationRequested;

volatile bool _userCancellationRequested;
internal CancellationToken UserCancellationToken { get; set; }

static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlConnector));

Expand Down Expand Up @@ -1065,7 +1065,7 @@ internal ValueTask<IBackendMessage> ReadMessage(bool async, DataRowLoadingMode d

while (true)
{
await ReadBuffer.Ensure(5, async, readingNotifications2, cancellationToken);
await ReadBuffer.Ensure(5, async, readingNotifications2, cancellationToken2);
messageCode = (BackendMessageCode)ReadBuffer.ReadByte();
PGUtil.ValidateBackendMessageCode(messageCode);
len = ReadBuffer.ReadInt32() - 4; // Transmitted length includes itself
Expand All @@ -1076,7 +1076,7 @@ internal ValueTask<IBackendMessage> ReadMessage(bool async, DataRowLoadingMode d
{
if (dataRowLoadingMode2 == DataRowLoadingMode.Skip)
{
await ReadBuffer.Skip(len, async, cancellationToken);
await ReadBuffer.Skip(len, async, cancellationToken2);
continue;
}
}
Expand All @@ -1094,7 +1094,7 @@ internal ValueTask<IBackendMessage> ReadMessage(bool async, DataRowLoadingMode d
ReadBuffer = oversizeBuffer;
}

await ReadBuffer.Ensure(len, async, cancellationToken);
await ReadBuffer.Ensure(len, async, cancellationToken2);
}

var msg = ParseServerMessage(ReadBuffer, messageCode, len, isReadingPrependedMessage);
Expand Down Expand Up @@ -1157,7 +1157,9 @@ internal ValueTask<IBackendMessage> ReadMessage(bool async, DataRowLoadingMode d
{
// User requested the cancellation - translate the PostgresException to an OperationCanceledException (keeping the former as the inner)
if (_userCancellationRequested)
throw new OperationCanceledException("Query was cancelled", e, cancellationToken2);
{
throw new OperationCanceledException("Query was cancelled", e, UserCancellationToken);
}

// We've timed out, send the cancellation request and successfully read it
throw new NpgsqlException("Exception while reading from stream", new TimeoutException("Timeout during reading attempt"));
Expand All @@ -1168,7 +1170,7 @@ internal ValueTask<IBackendMessage> ReadMessage(bool async, DataRowLoadingMode d
catch (NpgsqlException e) when (e.InnerException is TimeoutException && _userCancellationRequested)
{
// User requested the cancellation and it timed out
throw new OperationCanceledException("Query was cancelled", e.InnerException, cancellationToken2);
throw new OperationCanceledException("Query was cancelled", e.InnerException, UserCancellationToken);
}
catch (NpgsqlException)
{
Expand Down Expand Up @@ -1485,10 +1487,19 @@ void DoCancelRequest(int backendProcessId, int backendSecretKey)
}
}

internal void ResetCancellation()
/// <summary>
/// Resets cancellation-related state to prepare for a new cancellable operation.
/// </summary>
/// <param name="userCancellationToken">
/// The cancellation token provided by the user for the operation. This is stored on the connector, and will be referenced
/// by the <see cref="OperationCanceledException"/> if one is thrown.
/// </param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal void ResetCancellation(CancellationToken userCancellationToken)
{
_cancellationRequested = false;
_userCancellationRequested = false;
UserCancellationToken = userCancellationToken;
ReadBuffer.Cts.ResetCts();
}

Expand Down
9 changes: 9 additions & 0 deletions src/Npgsql/NpgsqlDataReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,10 @@ async Task<bool> Read(bool async, CancellationToken cancellationToken = default)
CancellationTokenRegistration? registration = null;

if (cancellationToken.CanBeCanceled)
{
Connector.UserCancellationToken = cancellationToken;
registration = cancellationToken.Register(reader => ((NpgsqlDataReader)reader!).Cancel(), this);
}

try
{
Expand Down Expand Up @@ -330,7 +333,10 @@ async Task<bool> NextResult(bool async, bool isConsuming = false, CancellationTo
CancellationTokenRegistration? registration = null;

if (cancellationToken.CanBeCanceled)
{
Connector.UserCancellationToken = cancellationToken;
registration = cancellationToken.Register(reader => ((NpgsqlDataReader)reader!).Cancel(), this);
}

try
{
Expand Down Expand Up @@ -568,7 +574,10 @@ async Task<bool> NextResultSchemaOnly(bool async, CancellationToken cancellation
CancellationTokenRegistration? registration = null;

if (cancellationToken.CanBeCanceled)
{
Connector.UserCancellationToken = cancellationToken;
registration = cancellationToken.Register(reader => ((NpgsqlDataReader)reader!).Cancel(), this);
}

try
{
Expand Down
22 changes: 14 additions & 8 deletions test/Npgsql.Tests/CommandTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ public async Task TimeoutAsyncHard()
await using var postmasterMock = PgPostmasterMock.Start(builder.ConnectionString);
using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString);
await using var conn = await OpenConnectionAsync(connectionString);
await postmasterMock.WaitForServerConnection();

var processId = conn.ProcessID;

Expand All @@ -206,7 +207,7 @@ public async Task TimeoutAsyncHard()
.With.InnerException.TypeOf<TimeoutException>());

Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken));
Assert.That(postmasterMock.GetPendingCancellationRequest().ProcessId,
Assert.That((await postmasterMock.WaitForCancellationRequest()).ProcessId,
Is.EqualTo(processId));
}

Expand Down Expand Up @@ -299,11 +300,14 @@ public async Task CancelAsyncSoft()
var cancellationSource = new CancellationTokenSource();
var t = cmd.ExecuteNonQueryAsync(cancellationSource.Token);
cancellationSource.Cancel();
Assert.That(async () => await t, Throws.Exception.TypeOf<OperationCanceledException>()
.With.InnerException.TypeOf<PostgresException>()
.With.InnerException.Property(nameof(PostgresException.SqlState)).EqualTo("57014"));

var exception = Assert.ThrowsAsync<OperationCanceledException>(async () => await t);
Assert.That(exception.InnerException,
Is.TypeOf<PostgresException>().With.Property(nameof(PostgresException.SqlState)).EqualTo("57014"));
Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token));

Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open));
Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1));
}

[Test, Description("Cancels an async query with the cancellation token, with unsuccessful PG cancellation (socket break)")]
Expand All @@ -315,19 +319,21 @@ public async Task CancelAsyncHard()
await using var postmasterMock = PgPostmasterMock.Start(ConnectionString);
using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString);
await using var conn = await OpenConnectionAsync(connectionString);
await postmasterMock.WaitForServerConnection();

var processId = conn.ProcessID;

var cancellationSource = new CancellationTokenSource();
using var cmd = new NpgsqlCommand("SELECT 1", conn);
var t = cmd.ExecuteScalarAsync(cancellationSource.Token);
cancellationSource.Cancel();
Assert.That(async () => await t, Throws.Exception
.TypeOf<OperationCanceledException>()
.With.InnerException.TypeOf<TimeoutException>());

var exception = Assert.ThrowsAsync<OperationCanceledException>(async () => await t);
Assert.That(exception.InnerException, Is.TypeOf<TimeoutException>());
Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token));

Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken));
Assert.That(postmasterMock.GetPendingCancellationRequest().ProcessId,
Assert.That((await postmasterMock.WaitForCancellationRequest()).ProcessId,
Is.EqualTo(processId));
}

Expand Down
Loading