diff --git a/src/Npgsql/BackendMessages/ErrorOrNoticeMessage.cs b/src/Npgsql/BackendMessages/ErrorOrNoticeMessage.cs index 2c675190c6..56249042a2 100644 --- a/src/Npgsql/BackendMessages/ErrorOrNoticeMessage.cs +++ b/src/Npgsql/BackendMessages/ErrorOrNoticeMessage.cs @@ -162,7 +162,7 @@ internal ErrorOrNoticeMessage( /// /// Error and notice message field codes /// - enum ErrorFieldTypeCode : byte + internal enum ErrorFieldTypeCode : byte { Done = 0, Severity = (byte)'S', diff --git a/src/Npgsql/NpgsqlCommand.cs b/src/Npgsql/NpgsqlCommand.cs index a10ca192bd..52175f5444 100644 --- a/src/Npgsql/NpgsqlCommand.cs +++ b/src/Npgsql/NpgsqlCommand.cs @@ -1153,7 +1153,7 @@ internal async ValueTask ExecuteReader(CommandBehavior behavio if (conn.TryGetBoundConnector(out var connector)) { connector.StartUserAction(this); - connector.ResetCancellation(); + connector.ResetCancellation(cancellationToken); CancellationTokenRegistration? registration = null; diff --git a/src/Npgsql/NpgsqlConnector.cs b/src/Npgsql/NpgsqlConnector.cs index 945036b08b..b8ea045aa7 100644 --- a/src/Npgsql/NpgsqlConnector.cs +++ b/src/Npgsql/NpgsqlConnector.cs @@ -222,8 +222,8 @@ internal void FlagAsWritableForMultiplexing() internal int ClearCounter { get; set; } volatile bool _cancellationRequested; - volatile bool _userCancellationRequested; + internal CancellationToken UserCancellationToken { get; set; } static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlConnector)); @@ -1065,7 +1065,7 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d while (true) { - await ReadBuffer.Ensure(5, async, readingNotifications2, cancellationToken); + await ReadBuffer.Ensure(5, async, readingNotifications2, cancellationToken2); messageCode = (BackendMessageCode)ReadBuffer.ReadByte(); PGUtil.ValidateBackendMessageCode(messageCode); len = ReadBuffer.ReadInt32() - 4; // Transmitted length includes itself @@ -1076,7 +1076,7 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d { if (dataRowLoadingMode2 == DataRowLoadingMode.Skip) { - await ReadBuffer.Skip(len, async, cancellationToken); + await ReadBuffer.Skip(len, async, cancellationToken2); continue; } } @@ -1094,7 +1094,7 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d ReadBuffer = oversizeBuffer; } - await ReadBuffer.Ensure(len, async, cancellationToken); + await ReadBuffer.Ensure(len, async, cancellationToken2); } var msg = ParseServerMessage(ReadBuffer, messageCode, len, isReadingPrependedMessage); @@ -1157,7 +1157,9 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d { // User requested the cancellation - translate the PostgresException to an OperationCanceledException (keeping the former as the inner) if (_userCancellationRequested) - throw new OperationCanceledException("Query was cancelled", e, cancellationToken2); + { + throw new OperationCanceledException("Query was cancelled", e, UserCancellationToken); + } // We've timed out, send the cancellation request and successfully read it throw new NpgsqlException("Exception while reading from stream", new TimeoutException("Timeout during reading attempt")); @@ -1168,7 +1170,7 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d catch (NpgsqlException e) when (e.InnerException is TimeoutException && _userCancellationRequested) { // User requested the cancellation and it timed out - throw new OperationCanceledException("Query was cancelled", e.InnerException, cancellationToken2); + throw new OperationCanceledException("Query was cancelled", e.InnerException, UserCancellationToken); } catch (NpgsqlException) { @@ -1485,10 +1487,19 @@ void DoCancelRequest(int backendProcessId, int backendSecretKey) } } - internal void ResetCancellation() + /// + /// Resets cancellation-related state to prepare for a new cancellable operation. + /// + /// + /// The cancellation token provided by the user for the operation. This is stored on the connector, and will be referenced + /// by the if one is thrown. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ResetCancellation(CancellationToken userCancellationToken) { _cancellationRequested = false; _userCancellationRequested = false; + UserCancellationToken = userCancellationToken; ReadBuffer.Cts.ResetCts(); } diff --git a/src/Npgsql/NpgsqlDataReader.cs b/src/Npgsql/NpgsqlDataReader.cs index 629f1ce8b1..e14822bd6d 100644 --- a/src/Npgsql/NpgsqlDataReader.cs +++ b/src/Npgsql/NpgsqlDataReader.cs @@ -230,7 +230,10 @@ async Task Read(bool async, CancellationToken cancellationToken = default) CancellationTokenRegistration? registration = null; if (cancellationToken.CanBeCanceled) + { + Connector.UserCancellationToken = cancellationToken; registration = cancellationToken.Register(reader => ((NpgsqlDataReader)reader!).Cancel(), this); + } try { @@ -330,7 +333,10 @@ async Task NextResult(bool async, bool isConsuming = false, CancellationTo CancellationTokenRegistration? registration = null; if (cancellationToken.CanBeCanceled) + { + Connector.UserCancellationToken = cancellationToken; registration = cancellationToken.Register(reader => ((NpgsqlDataReader)reader!).Cancel(), this); + } try { @@ -568,7 +574,10 @@ async Task NextResultSchemaOnly(bool async, CancellationToken cancellation CancellationTokenRegistration? registration = null; if (cancellationToken.CanBeCanceled) + { + Connector.UserCancellationToken = cancellationToken; registration = cancellationToken.Register(reader => ((NpgsqlDataReader)reader!).Cancel(), this); + } try { diff --git a/test/Npgsql.Tests/CommandTests.cs b/test/Npgsql.Tests/CommandTests.cs index 85d24c1efc..4cceed3105 100644 --- a/test/Npgsql.Tests/CommandTests.cs +++ b/test/Npgsql.Tests/CommandTests.cs @@ -197,6 +197,7 @@ public async Task TimeoutAsyncHard() await using var postmasterMock = PgPostmasterMock.Start(builder.ConnectionString); using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); await using var conn = await OpenConnectionAsync(connectionString); + await postmasterMock.WaitForServerConnection(); var processId = conn.ProcessID; @@ -206,7 +207,7 @@ public async Task TimeoutAsyncHard() .With.InnerException.TypeOf()); Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - Assert.That(postmasterMock.GetPendingCancellationRequest().ProcessId, + Assert.That((await postmasterMock.WaitForCancellationRequest()).ProcessId, Is.EqualTo(processId)); } @@ -299,11 +300,14 @@ public async Task CancelAsyncSoft() var cancellationSource = new CancellationTokenSource(); var t = cmd.ExecuteNonQueryAsync(cancellationSource.Token); cancellationSource.Cancel(); - Assert.That(async () => await t, Throws.Exception.TypeOf() - .With.InnerException.TypeOf() - .With.InnerException.Property(nameof(PostgresException.SqlState)).EqualTo("57014")); + + var exception = Assert.ThrowsAsync(async () => await t); + Assert.That(exception.InnerException, + Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo("57014")); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); } [Test, Description("Cancels an async query with the cancellation token, with unsuccessful PG cancellation (socket break)")] @@ -315,6 +319,7 @@ public async Task CancelAsyncHard() await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); await using var conn = await OpenConnectionAsync(connectionString); + await postmasterMock.WaitForServerConnection(); var processId = conn.ProcessID; @@ -322,12 +327,13 @@ public async Task CancelAsyncHard() using var cmd = new NpgsqlCommand("SELECT 1", conn); var t = cmd.ExecuteScalarAsync(cancellationSource.Token); cancellationSource.Cancel(); - Assert.That(async () => await t, Throws.Exception - .TypeOf() - .With.InnerException.TypeOf()); + + var exception = Assert.ThrowsAsync(async () => await t); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); - Assert.That(postmasterMock.GetPendingCancellationRequest().ProcessId, + Assert.That((await postmasterMock.WaitForCancellationRequest()).ProcessId, Is.EqualTo(processId)); } diff --git a/test/Npgsql.Tests/ReaderTests.cs b/test/Npgsql.Tests/ReaderTests.cs index cfe2814418..403fe183cb 100644 --- a/test/Npgsql.Tests/ReaderTests.cs +++ b/test/Npgsql.Tests/ReaderTests.cs @@ -6,9 +6,11 @@ using System.Linq; using System.Runtime.InteropServices; using System.Text; +using System.Threading; using System.Threading.Tasks; using Npgsql.BackendMessages; using Npgsql.PostgresTypes; +using Npgsql.Tests.Support; using Npgsql.TypeHandling; using Npgsql.TypeMapping; using NpgsqlTypes; @@ -1532,6 +1534,198 @@ public async Task NonSafeReadException() } #endif + #region Cancellation + + [Test, Description("Cancels ReadAsync via the cancellation token, with successful PG cancellation")] + public async Task ReadAsync_cancel_soft() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); + await using var conn = await OpenConnectionAsync(connectionString); + + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) + .FlushAsync(); + + using var cmd = new NpgsqlCommand("SELECT some_int FROM some_table", conn); + await using (var reader = await cmd.ExecuteReaderAsync()) + { + // Successfully read the first row + Assert.True(await reader.ReadAsync()); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + + // Attempt to read the second row - simulate blocking and cancellation + var cancellationSource = new CancellationTokenSource(); + var task = reader.ReadAsync(cancellationSource.Token); + cancellationSource.Cancel(); + + var (processId, _) = await postmasterMock.WaitForCancellationRequest(); + Assert.That(processId, Is.EqualTo(conn.ProcessID)); + + await pgMock + .WriteErrorResponse(PostgresErrorCodes.QueryCanceled) + .WriteReadyForQuery() + .FlushAsync(); + + var exception = Assert.ThrowsAsync(async () => await task); + Assert.That(exception.InnerException, + Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo("57014")); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + } + + await pgMock.WriteScalarResponse(1); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test, Description("Cancels NextResultAsync via the cancellation token, with successful PG cancellation")] + public async Task NextResult_cancel_soft() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); + await using var conn = await OpenConnectionAsync(connectionString); + + // Write responses to the query we're about to send, only for the first resultset (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) + .WriteCommandComplete() + .FlushAsync(); + + using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn); + await using (var reader = await cmd.ExecuteReaderAsync()) + { + // Successfully read the first resultset + Assert.True(await reader.ReadAsync()); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + + // Attempt to advance to the second resultset - simulate blocking and cancellation + var cancellationSource = new CancellationTokenSource(); + var task = reader.NextResultAsync(cancellationSource.Token); + cancellationSource.Cancel(); + + var (processId, _) = await postmasterMock.WaitForCancellationRequest(); + Assert.That(processId, Is.EqualTo(conn.ProcessID)); + + await pgMock + .WriteErrorResponse(PostgresErrorCodes.QueryCanceled) + .WriteReadyForQuery() + .FlushAsync(); + + var exception = Assert.ThrowsAsync(async () => await task); + Assert.That(exception.InnerException, + Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo("57014")); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + } + + await pgMock.WriteScalarResponse(1); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test, Description("Cancels ReadAsync via the cancellation token, with unsuccessful PG cancellation (socket break)")] + public async Task ReadAsync_cancel_hard() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); + await using var conn = await OpenConnectionAsync(connectionString); + + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) + .FlushAsync(); + + using var cmd = new NpgsqlCommand("SELECT some_int FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + + // Successfully read the first row + Assert.True(await reader.ReadAsync()); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + + // Attempt to read the second row - simulate blocking and cancellation + var cancellationSource = new CancellationTokenSource(); + var task = reader.ReadAsync(cancellationSource.Token); + cancellationSource.Cancel(); + + var (processId, _) = await postmasterMock.WaitForCancellationRequest(); + Assert.That(processId, Is.EqualTo(conn.ProcessID)); + + // Send no response from server, wait for the cancellation attempt to time out + var exception = Assert.ThrowsAsync(async () => await task); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } + + [Test, Description("Cancels ReadAsync via the cancellation token, with unsuccessful PG cancellation (socket break)")] + public async Task NextResultAsync_cancel_hard() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); + await using var conn = await OpenConnectionAsync(connectionString); + + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteDataRow(BitConverter.GetBytes(BinaryPrimitives.ReverseEndianness(1))) + .WriteCommandComplete() + .FlushAsync(); + + using var cmd = new NpgsqlCommand("SELECT some_int FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + + // Successfully read the first resultset + Assert.True(await reader.ReadAsync()); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + + // Attempt to read the second row - simulate blocking and cancellation + var cancellationSource = new CancellationTokenSource(); + var task = reader.NextResultAsync(cancellationSource.Token); + cancellationSource.Cancel(); + + var (processId, _) = await postmasterMock.WaitForCancellationRequest(); + Assert.That(processId, Is.EqualTo(conn.ProcessID)); + + // Send no response from server, wait for the cancellation attempt to time out + var exception = Assert.ThrowsAsync(async () => await task); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } + + #endregion Cancellation + #region Initialization / setup / teardown // ReSharper disable InconsistentNaming diff --git a/test/Npgsql.Tests/Support/PgPostmasterMock.cs b/test/Npgsql.Tests/Support/PgPostmasterMock.cs index 1c88bbdf74..1c93716338 100644 --- a/test/Npgsql.Tests/Support/PgPostmasterMock.cs +++ b/test/Npgsql.Tests/Support/PgPostmasterMock.cs @@ -4,6 +4,7 @@ using System.Net; using System.Net.Sockets; using System.Text; +using System.Threading.Channels; using System.Threading.Tasks; using Npgsql.Util; using NUnit.Framework.Constraints; @@ -21,12 +22,13 @@ class PgPostmasterMock : IAsyncDisposable readonly Socket _socket; readonly List _allServers = new List(); - readonly Queue _pendingServers = new Queue(); - readonly Queue<(int ProcessId, int Secret)> _pendingCancellationRequests - = new Queue<(int ProcessId, int Secret)>(); + bool _acceptingClients; Task? _acceptClientsTask; int _processIdCounter; + ChannelWriter _pendingRequestsWriter { get; } + internal ChannelReader PendingRequestsReader { get; } + internal string ConnectionString { get; } internal static PgPostmasterMock Start(string? connectionString = null) @@ -38,6 +40,10 @@ internal static PgPostmasterMock Start(string? connectionString = null) internal PgPostmasterMock(string? connectionString = null) { + var pendingRequestsChannel = Channel.CreateUnbounded(); + PendingRequestsReader = pendingRequestsChannel.Reader; + _pendingRequestsWriter = pendingRequestsChannel.Writer; + var connectionStringBuilder = new NpgsqlConnectionStringBuilder(connectionString ?? TestUtil.ConnectionString); @@ -55,6 +61,7 @@ internal PgPostmasterMock(string? connectionString = null) void AcceptClients() { + _acceptingClients = true; _acceptClientsTask = DoAcceptClients(); async Task DoAcceptClients() @@ -64,11 +71,14 @@ async Task DoAcceptClients() var serverOrCancellationRequest = await Accept(); if (serverOrCancellationRequest.Server is { } server) { - _pendingServers.Enqueue(server); - await server.Startup(); + // Hand off the new server to the client test only once startup is complete, to avoid reading/writing in parallel + // during startup. Don't wait for all this to complete - continue to accept other connections in case that's needed. + _ = server.Startup().ContinueWith(t => _pendingRequestsWriter.WriteAsync(serverOrCancellationRequest)); } else - _pendingCancellationRequests.Enqueue(serverOrCancellationRequest.CancellationRequest!.Value); + { + await _pendingRequestsWriter.WriteAsync(serverOrCancellationRequest); + } } // ReSharper disable once FunctionNeverReturns @@ -105,21 +115,39 @@ internal async Task Accept() internal async Task AcceptServer() { + if (_acceptingClients) + throw new InvalidOperationException($"Already accepting clients via {nameof(AcceptClients)}"); var serverOrCancellationRequest = await Accept(); if (serverOrCancellationRequest.Server is null) - throw new InvalidOperationException("Expected new server connection but a cancellation request occurred instead"); + throw new InvalidOperationException("Expected a server connection but got a cancellation request instead"); return serverOrCancellationRequest.Server; } - internal PgServerMock GetPendingServer() - => _pendingServers.TryDequeue(out var server) - ? server - : throw new InvalidOperationException("No pending server"); + internal async Task<(int ProcessId, int Secret)> AcceptCancellationRequest() + { + if (_acceptingClients) + throw new InvalidOperationException($"Already accepting clients via {nameof(AcceptClients)}"); + var serverOrCancellationRequest = await Accept(); + if (serverOrCancellationRequest.CancellationRequest is null) + throw new InvalidOperationException("Expected a cancellation request but got a server connection instead"); + return serverOrCancellationRequest.CancellationRequest.Value; + } + + internal async ValueTask WaitForServerConnection() + { + var serverOrCancellationRequest = await PendingRequestsReader.ReadAsync(); + if (serverOrCancellationRequest.Server is null) + throw new InvalidOperationException("Expected a server connection but got a cancellation request instead"); + return serverOrCancellationRequest.Server; + } - internal (int ProcessId, int Secret) GetPendingCancellationRequest() - => _pendingCancellationRequests.TryDequeue(out var cancellationRequest) - ? cancellationRequest - : throw new InvalidOperationException("No pending cancellation request"); + internal async ValueTask<(int ProcessId, int Secret)> WaitForCancellationRequest() + { + var serverOrCancellationRequest = await PendingRequestsReader.ReadAsync(); + if (serverOrCancellationRequest.CancellationRequest is null) + throw new InvalidOperationException("Expected cancellation request but got a server connection instead"); + return serverOrCancellationRequest.CancellationRequest.Value; + } public async ValueTask DisposeAsync() { diff --git a/test/Npgsql.Tests/Support/PgServerMock.cs b/test/Npgsql.Tests/Support/PgServerMock.cs index 16c32c4515..5f3f377e30 100644 --- a/test/Npgsql.Tests/Support/PgServerMock.cs +++ b/test/Npgsql.Tests/Support/PgServerMock.cs @@ -215,6 +215,29 @@ internal PgServerMock WriteBackendKeyData(int processId, int secret) return this; } + internal PgServerMock WriteErrorResponse(string code) + => WriteErrorResponse(code, "ERROR", "MOCK ERROR MESSAGE"); + + internal PgServerMock WriteErrorResponse(string code, string severity, string message) + { + CheckDisposed(); + _writeBuffer.WriteByte((byte)BackendMessageCode.ErrorResponse); + _writeBuffer.WriteInt32( + 4 + + 1 + Encoding.GetByteCount(code) + + 1 + Encoding.GetByteCount(severity) + + 1 + Encoding.GetByteCount(message) + + 1); + _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Code); + _writeBuffer.WriteNullTerminatedString(code); + _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Severity); + _writeBuffer.WriteNullTerminatedString(severity); + _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Message); + _writeBuffer.WriteNullTerminatedString(message); + _writeBuffer.WriteByte((byte)ErrorOrNoticeMessage.ErrorFieldTypeCode.Done); + return this; + } + #endregion Low-level message writing void CheckDisposed()