diff --git a/src/Npgsql/NpgsqlBinaryExporter.cs b/src/Npgsql/NpgsqlBinaryExporter.cs index 0472f73dfb..c1bd27325d 100644 --- a/src/Npgsql/NpgsqlBinaryExporter.cs +++ b/src/Npgsql/NpgsqlBinaryExporter.cs @@ -52,7 +52,7 @@ internal NpgsqlBinaryExporter(NpgsqlConnector connector, string copyToCommand) _connector.Flush(); CopyOutResponseMessage copyOutResponse; - var msg = _connector.ReadMessage(); + var msg = _connector.ReadMessageWithoutCancellation(); switch (msg.Code) { case BackendMessageCode.CopyOutResponse: @@ -144,9 +144,9 @@ async ValueTask StartRow(bool async, CancellationToken cancellationToken = if (numColumns == -1) { Debug.Assert(_leftToReadInDataMsg == 0); - Expect(await _connector.ReadMessage(async, cancellationToken), _connector); - Expect(await _connector.ReadMessage(async, cancellationToken), _connector); - Expect(await _connector.ReadMessage(async, cancellationToken), _connector); + Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken), _connector); + Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken), _connector); + Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken), _connector); _column = -1; _isConsumed = true; return -1; @@ -384,8 +384,8 @@ async ValueTask DisposeAsync(bool async) // Read to the end _connector.SkipUntil(BackendMessageCode.CopyDone); // We intentionally do not pass a CancellationToken since we don't want to cancel cleanup - Expect(await _connector.ReadMessage(async, cancellationToken: default), _connector); - Expect(await _connector.ReadMessage(async, cancellationToken: default), _connector); + Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken: default), _connector); + Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken: default), _connector); } _connector.EndUserAction(); diff --git a/src/Npgsql/NpgsqlBinaryImporter.cs b/src/Npgsql/NpgsqlBinaryImporter.cs index d52b22724a..6c39942a98 100644 --- a/src/Npgsql/NpgsqlBinaryImporter.cs +++ b/src/Npgsql/NpgsqlBinaryImporter.cs @@ -56,7 +56,7 @@ internal NpgsqlBinaryImporter(NpgsqlConnector connector, string copyFromCommand) _connector.Flush(); CopyInResponseMessage copyInResponse; - var msg = _connector.ReadMessage(); + var msg = _connector.ReadMessageWithoutCancellation(); switch (msg.Code) { case BackendMessageCode.CopyInResponse: @@ -408,8 +408,8 @@ async ValueTask Complete(bool async, CancellationToken cancellationToken _buf.EndCopyMode(); await _connector.WriteCopyDone(async, cancellationToken); await _connector.Flush(async, cancellationToken); - var cmdComplete = Expect(await _connector.ReadMessage(async, cancellationToken), _connector); - Expect(await _connector.ReadMessage(async, cancellationToken), _connector); + var cmdComplete = Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken), _connector); + Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken), _connector); _state = ImporterState.Committed; return cmdComplete.Rows; } @@ -447,7 +447,7 @@ async Task Cancel(bool async, CancellationToken cancellationToken = default) await _connector.Flush(async, cancellationToken); try { - var msg = await _connector.ReadMessage(async, cancellationToken); + var msg = await _connector.ReadMessageWithoutCancellation(async, cancellationToken); // The CopyFail should immediately trigger an exception from the read above. throw _connector.Break( new NpgsqlException("Expected ErrorResponse when cancelling COPY but got: " + msg.Code)); diff --git a/src/Npgsql/NpgsqlCommand.cs b/src/Npgsql/NpgsqlCommand.cs index 52175f5444..67717c22f3 100644 --- a/src/Npgsql/NpgsqlCommand.cs +++ b/src/Npgsql/NpgsqlCommand.cs @@ -480,8 +480,8 @@ void DeriveParametersForQuery(NpgsqlConnector connector) foreach (var statement in _statements) { - Expect(connector.ReadMessage(), connector); - var paramTypeOIDs = Expect(connector.ReadMessage(), connector).TypeOIDs; + Expect(connector.ReadMessageWithoutCancellation(), connector); + var paramTypeOIDs = Expect(connector.ReadMessageWithoutCancellation(), connector).TypeOIDs; if (statement.InputParameters.Count != paramTypeOIDs.Count) { @@ -515,7 +515,7 @@ void DeriveParametersForQuery(NpgsqlConnector connector) } } - var msg = connector.ReadMessage(); + var msg = connector.ReadMessageWithoutCancellation(); switch (msg.Code) { case BackendMessageCode.RowDescription: @@ -526,7 +526,7 @@ void DeriveParametersForQuery(NpgsqlConnector connector) } } - Expect(connector.ReadMessage(), connector); + Expect(connector.ReadMessageWithoutCancellation(), connector); sendTask.GetAwaiter().GetResult(); } } diff --git a/src/Npgsql/NpgsqlConnector.Auth.cs b/src/Npgsql/NpgsqlConnector.Auth.cs index dbd0448e58..45288671ba 100644 --- a/src/Npgsql/NpgsqlConnector.Auth.cs +++ b/src/Npgsql/NpgsqlConnector.Auth.cs @@ -22,7 +22,7 @@ async Task Authenticate(string username, NpgsqlTimeout timeout, bool async, Canc Log.Trace("Authenticating...", Id); timeout.CheckAndApply(this); - var msg = Expect(await ReadMessage(async, cancellationToken), this); + var msg = Expect(await ReadMessageWithoutCancellation(async, cancellationToken), this); switch (msg.AuthRequestType) { case AuthenticationRequestType.AuthenticationOk: @@ -64,7 +64,7 @@ async Task AuthenticateCleartext(string username, bool async, CancellationToken await WritePassword(encoded, async, cancellationToken); await Flush(async, cancellationToken); - Expect(await ReadMessage(async, cancellationToken), this); + Expect(await ReadMessageWithoutCancellation(async, cancellationToken), this); } async Task AuthenticateSASL(List mechanisms, string username, bool async, CancellationToken cancellationToken = default) @@ -164,7 +164,7 @@ async Task AuthenticateSASL(List mechanisms, string username, bool async await WriteSASLInitialResponse(mechanism, PGUtil.UTF8Encoding.GetBytes($"{cbindFlag},,n=*,r={clientNonce}"), async, cancellationToken); await Flush(async, cancellationToken); - var saslContinueMsg = Expect(await ReadMessage(async, cancellationToken), this); + var saslContinueMsg = Expect(await ReadMessageWithoutCancellation(async, cancellationToken), this); if (saslContinueMsg.AuthRequestType != AuthenticationRequestType.AuthenticationSASLContinue) throw new NpgsqlException("[SASL] AuthenticationSASLFinal message expected"); var firstServerMsg = AuthenticationSCRAMServerFirstMessage.Load(saslContinueMsg.Payload); @@ -197,7 +197,7 @@ async Task AuthenticateSASL(List mechanisms, string username, bool async await WriteSASLResponse(Encoding.UTF8.GetBytes(messageStr), async, cancellationToken); await Flush(async, cancellationToken); - var saslFinalServerMsg = Expect(await ReadMessage(async, cancellationToken), this); + var saslFinalServerMsg = Expect(await ReadMessageWithoutCancellation(async, cancellationToken), this); if (saslFinalServerMsg.AuthRequestType != AuthenticationRequestType.AuthenticationSASLFinal) throw new NpgsqlException("[SASL] AuthenticationSASLFinal message expected"); @@ -205,7 +205,7 @@ async Task AuthenticateSASL(List mechanisms, string username, bool async if (scramFinalServerMsg.ServerSignature != Convert.ToBase64String(serverSignature)) throw new NpgsqlException("[SCRAM] Unable to verify server signature"); - var okMsg = Expect(await ReadMessage(async, cancellationToken), this); + var okMsg = Expect(await ReadMessageWithoutCancellation(async, cancellationToken), this); if (okMsg.AuthRequestType != AuthenticationRequestType.AuthenticationOk) throw new NpgsqlException("[SASL] Expected AuthenticationOK message"); @@ -297,7 +297,7 @@ async Task AuthenticateMD5(string username, byte[] salt, bool async, Cancellatio await WritePassword(result, async, cancellationToken); await Flush(async, cancellationToken); - Expect(await ReadMessage(async, cancellationToken), this); + Expect(await ReadMessageWithoutCancellation(async, cancellationToken), this); } async Task AuthenticateGSS(bool async) @@ -393,7 +393,7 @@ async Task Read(byte[] buffer, int offset, int count, bool async, Cancellat { if (_leftToRead == 0) { - var response = Expect(await _connector.ReadMessage(async, cancellationToken), _connector); + var response = Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken), _connector); if (response.AuthRequestType == AuthenticationRequestType.AuthenticationOk) throw new AuthenticationCompleteException(); var gssMsg = response as AuthenticationGSSContinueMessage; diff --git a/src/Npgsql/NpgsqlConnector.cs b/src/Npgsql/NpgsqlConnector.cs index 3db2356d49..085287f7ae 100644 --- a/src/Npgsql/NpgsqlConnector.cs +++ b/src/Npgsql/NpgsqlConnector.cs @@ -223,6 +223,9 @@ internal void FlagAsWritableForMultiplexing() volatile bool _cancellationRequested; volatile bool _userCancellationRequested; + + internal bool UserCancellationRequested => _userCancellationRequested; + internal CancellationToken UserCancellationToken { get; set; } static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlConnector)); @@ -434,13 +437,13 @@ internal async Task Open(NpgsqlTimeout timeout, bool async, CancellationToken ca // We treat BackendKeyData as optional because some PostgreSQL-like database // don't send it (CockroachDB, CrateDB) - var msg = await ReadMessage(async, cancellationToken); + var msg = await ReadMessageWithoutCancellation(async, cancellationToken); if (msg.Code == BackendMessageCode.BackendKeyData) { var keyDataMsg = (BackendKeyDataMessage)msg; BackendProcessId = keyDataMsg.BackendProcessId; _backendSecretKey = keyDataMsg.BackendSecretKey; - msg = await ReadMessage(async, cancellationToken); + msg = await ReadMessageWithoutCancellation(async, cancellationToken); } if (msg.Code != BackendMessageCode.ReadyForQuery) throw new NpgsqlException($"Received backend message {msg.Code} while expecting ReadyForQuery. Please file a bug."); @@ -997,8 +1000,17 @@ internal void PrependInternalMessage(byte[] rawMessage, int responseMessageCount #region Backend message processing + internal IBackendMessage ReadMessageWithoutCancellation(DataRowLoadingMode dataRowLoadingMode = DataRowLoadingMode.NonSequential) + => ReadMessageWithoutCancellation(false, dataRowLoadingMode, cancellationToken: default).GetAwaiter().GetResult()!; + + internal ValueTask ReadMessageWithoutCancellation(bool async, DataRowLoadingMode dataRowLoadingMode = DataRowLoadingMode.NonSequential, CancellationToken cancellationToken = default) + => DoReadMessage(async, dataRowLoadingMode, attemptPostgresCancellation: false, cancellationToken: cancellationToken)!; + + internal ValueTask ReadMessageWithoutCancellation(bool async, CancellationToken cancellationToken = default) + => DoReadMessage(async, DataRowLoadingMode.NonSequential, attemptPostgresCancellation: false, cancellationToken: cancellationToken)!; + internal IBackendMessage ReadMessage(DataRowLoadingMode dataRowLoadingMode = DataRowLoadingMode.NonSequential) - => ReadMessage(false, dataRowLoadingMode, cancellationToken: default).GetAwaiter().GetResult()!; + => ReadMessage(false, dataRowLoadingMode, cancellationToken: default).GetAwaiter().GetResult(); internal ValueTask ReadMessage(bool async, CancellationToken cancellationToken = default) => DoReadMessage(async, DataRowLoadingMode.NonSequential, cancellationToken: cancellationToken)!; @@ -1007,12 +1019,14 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d => DoReadMessage(async, dataRowLoadingMode, cancellationToken: cancellationToken)!; internal ValueTask ReadMessageWithNotifications(bool async, CancellationToken cancellationToken = default) - => DoReadMessage(async, DataRowLoadingMode.NonSequential, true, cancellationToken: cancellationToken); + => DoReadMessage(async, DataRowLoadingMode.NonSequential, readingNotifications: true, attemptPostgresCancellation: false, + cancellationToken: cancellationToken); ValueTask DoReadMessage( bool async, DataRowLoadingMode dataRowLoadingMode = DataRowLoadingMode.NonSequential, bool readingNotifications = false, + bool attemptPostgresCancellation = true, CancellationToken cancellationToken = default) { if (_pendingPrependedResponses > 0 || @@ -1020,7 +1034,8 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d readingNotifications || ReadBuffer.ReadBytesLeft < 5) { - return ReadMessageLong(dataRowLoadingMode, readingNotifications, cancellationToken2: cancellationToken); + return ReadMessageLong(dataRowLoadingMode, readingNotifications2: readingNotifications, + attemptPostgresCancellation2: attemptPostgresCancellation, cancellationToken2: cancellationToken); } var messageCode = (BackendMessageCode)ReadBuffer.ReadByte(); @@ -1031,7 +1046,8 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d case BackendMessageCode.ParameterStatus: case BackendMessageCode.ErrorResponse: ReadBuffer.ReadPosition--; - return ReadMessageLong(dataRowLoadingMode, readingNotifications2: false, cancellationToken2: cancellationToken); + return ReadMessageLong(dataRowLoadingMode, readingNotifications2: false, attemptPostgresCancellation2: false, + cancellationToken2: cancellationToken); case BackendMessageCode.ReadyForQuery: break; } @@ -1041,7 +1057,8 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d if (len > ReadBuffer.ReadBytesLeft) { ReadBuffer.ReadPosition -= 5; - return ReadMessageLong(dataRowLoadingMode, readingNotifications2: false, cancellationToken2: cancellationToken); + return ReadMessageLong(dataRowLoadingMode, readingNotifications2: false, attemptPostgresCancellation2: false, + cancellationToken2: cancellationToken); } return new ValueTask(ParseServerMessage(ReadBuffer, messageCode, len, false)); @@ -1050,6 +1067,7 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d DataRowLoadingMode dataRowLoadingMode2, bool readingNotifications2, bool isReadingPrependedMessage = false, + bool attemptPostgresCancellation2 = true, CancellationToken cancellationToken2 = default) { // First read the responses of any prepended messages. @@ -1060,7 +1078,8 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d // TODO: There could be room for optimization here, rather than the async call(s) ReadBuffer.Timeout = TimeSpan.FromMilliseconds(InternalCommandTimeout); for (; _pendingPrependedResponses > 0; _pendingPrependedResponses--) - await ReadMessageLong(DataRowLoadingMode.Skip, false, true, cancellationToken2); + await ReadMessageLong(DataRowLoadingMode.Skip, readingNotifications2: false, isReadingPrependedMessage: true, + attemptPostgresCancellation2: false, cancellationToken2); } catch (PostgresException e) { @@ -1076,7 +1095,7 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d while (true) { - await ReadBuffer.Ensure(5, async, readingNotifications2, cancellationToken2); + await ReadBuffer.Ensure(5, async, readingNotifications2, attemptPostgresCancellation2, cancellationToken2); messageCode = (BackendMessageCode)ReadBuffer.ReadByte(); PGUtil.ValidateBackendMessageCode(messageCode); len = ReadBuffer.ReadInt32() - 4; // Transmitted length includes itself @@ -1178,11 +1197,6 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d throw; } - catch (NpgsqlException e) when (e.InnerException is TimeoutException && _userCancellationRequested) - { - // User requested the cancellation and it timed out - throw new OperationCanceledException("Query was cancelled", e.InnerException, UserCancellationToken); - } catch (NpgsqlException) { // An ErrorResponse isn't followed by ReadyForQuery @@ -1291,7 +1305,7 @@ internal IBackendMessage SkipUntil(BackendMessageCode stopAt) while (true) { - var msg = ReadMessage(false, DataRowLoadingMode.Skip).GetAwaiter().GetResult()!; + var msg = ReadMessageWithoutCancellation(false, DataRowLoadingMode.Skip).GetAwaiter().GetResult()!; Debug.Assert(!(msg is DataRowMessage)); if (msg.Code == stopAt) return msg; @@ -2016,7 +2030,7 @@ internal async Task Wait(bool async, int timeout, CancellationToken cancel while (true) { - var msg = await ReadMessage(async, cancellationToken); + var msg = await ReadMessageWithoutCancellation(async, cancellationToken); if (msg == null) { receivedNotification = true; @@ -2077,8 +2091,8 @@ internal async Task ExecuteInternalCommand(string query, bool async, Cancellatio await WriteQuery(query, async, cancellationToken); await Flush(async, cancellationToken); - Expect(await ReadMessage(async, cancellationToken), this); - Expect(await ReadMessage(async, cancellationToken), this); + Expect(await ReadMessageWithoutCancellation(async, cancellationToken), this); + Expect(await ReadMessageWithoutCancellation(async, cancellationToken), this); } internal async Task ExecuteInternalCommand(byte[] data, bool async, CancellationToken cancellationToken = default) @@ -2089,8 +2103,8 @@ internal async Task ExecuteInternalCommand(byte[] data, bool async, Cancellation await WritePregenerated(data, async, cancellationToken); await Flush(async, cancellationToken); - Expect(await ReadMessage(async, cancellationToken), this); - Expect(await ReadMessage(async, cancellationToken), this); + Expect(await ReadMessageWithoutCancellation(async, cancellationToken), this); + Expect(await ReadMessageWithoutCancellation(async, cancellationToken), this); } #endregion diff --git a/src/Npgsql/NpgsqlRawCopyStream.cs b/src/Npgsql/NpgsqlRawCopyStream.cs index d392277323..c28dc2c320 100644 --- a/src/Npgsql/NpgsqlRawCopyStream.cs +++ b/src/Npgsql/NpgsqlRawCopyStream.cs @@ -61,7 +61,7 @@ internal NpgsqlRawCopyStream(NpgsqlConnector connector, string copyCommand) _connector.WriteQuery(copyCommand); _connector.Flush(); - var msg = _connector.ReadMessage(); + var msg = _connector.ReadMessageWithoutCancellation(); switch (msg.Code) { case BackendMessageCode.CopyInResponse: @@ -270,7 +270,7 @@ async ValueTask ReadCore(int count, bool async, CancellationToken cancellat { // We've consumed the current DataMessage (or haven't yet received the first), // read the next message - msg = await _connector.ReadMessage(async, cancellationToken); + msg = await _connector.ReadMessageWithoutCancellation(async, cancellationToken); } catch { @@ -284,8 +284,8 @@ async ValueTask ReadCore(int count, bool async, CancellationToken cancellat _leftToReadInDataMsg = ((CopyDataMessage)msg).Length; break; case BackendMessageCode.CopyDone: - Expect(await _connector.ReadMessage(async, cancellationToken), _connector); - Expect(await _connector.ReadMessage(async, cancellationToken), _connector); + Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken), _connector); + Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken), _connector); _isConsumed = true; return 0; default: @@ -340,7 +340,7 @@ async Task Cancel(bool async) await _connector.Flush(async); try { - var msg = await _connector.ReadMessage(async, cancellationToken: default); + var msg = await _connector.ReadMessageWithoutCancellation(async, cancellationToken: default); // The CopyFail should immediately trigger an exception from the read above. throw _connector.Break( new NpgsqlException("Expected ErrorResponse when cancelling COPY but got: " + msg.Code)); @@ -379,8 +379,8 @@ async ValueTask DisposeAsync(bool disposing, bool async) _writeBuf.EndCopyMode(); await _connector.WriteCopyDone(async); await _connector.Flush(async); - Expect(await _connector.ReadMessage(async, cancellationToken: default), _connector); - Expect(await _connector.ReadMessage(async, cancellationToken : default), _connector); + Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken: default), _connector); + Expect(await _connector.ReadMessageWithoutCancellation(async, cancellationToken : default), _connector); } else { diff --git a/src/Npgsql/NpgsqlReadBuffer.cs b/src/Npgsql/NpgsqlReadBuffer.cs index 33fbc48868..50fc0d75de 100644 --- a/src/Npgsql/NpgsqlReadBuffer.cs +++ b/src/Npgsql/NpgsqlReadBuffer.cs @@ -128,16 +128,17 @@ internal void Ensure(int count) } public Task Ensure(int count, bool async, CancellationToken cancellationToken = default) - => Ensure(count, async, readingNotifications: false, cancellationToken); + => Ensure(count, async, readingNotifications: false, attemptPostgresCancellation: false, cancellationToken); public Task EnsureAsync(int count, CancellationToken cancellationToken = default) - => Ensure(count, async: true, readingNotifications: false, cancellationToken); + => Ensure(count, async: true, readingNotifications: false, attemptPostgresCancellation: false, cancellationToken); /// /// Ensures that bytes are available in the buffer, and if /// not, reads from the socket until enough is available. /// - internal Task Ensure(int count, bool async, bool readingNotifications, CancellationToken cancellationToken = default) + internal Task Ensure(int count, bool async, bool readingNotifications, bool attemptPostgresCancellation, + CancellationToken cancellationToken = default) { return count <= ReadBytesLeft ? Task.CompletedTask : EnsureLong(); @@ -202,10 +203,9 @@ async Task EnsureLong() switch (e) { - // User requested the cancellation (at this moment, it should be only WaitAsync) + // User requested the cancellation (at this moment, it is COPY operations, WaitAsync, Reader's sequential methods, authentication) case OperationCanceledException _ when cancellationToken.IsCancellationRequested: - Debug.Assert(readingNotifications); - throw; + throw readingNotifications ? e : Connector.Break(e); // Read timeout case OperationCanceledException _: @@ -216,11 +216,12 @@ async Task EnsureLong() Debug.Assert(e is OperationCanceledException ? async : !async); if (readingNotifications) - throw TimeoutException(); + throw NpgsqlTimeoutException(); // Note that if PG cancellation fails, the exception for that is already logged internally by CancelRequest. // We simply continue and throw the timeout one. - if (!wasCancellationRequested && Connector.CancelRequest(requestedByUser: false)) + // TODO: As an optimization, we can still attempt to send a cancellation request, but after that immediately break the connection + if (attemptPostgresCancellation && !wasCancellationRequested && Connector.CancelRequest(requestedByUser: false)) { // If the cancellation timeout is negative, we break the connection immediately var cancellationTimeout = Connector.Settings.CancellationTimeout; @@ -238,7 +239,16 @@ async Task EnsureLong() } } - throw Connector.Break(TimeoutException()); + // There is a case, when we might call a cancellable method (NpgsqlDataReader.NextResult) + // but it times out on a sequential read (NpgsqlDataReader.ConsumeRow) + if (Connector.UserCancellationRequested) + { + // User requested the cancellation and it timed out (or we didn't send it) + throw Connector.Break(new OperationCanceledException("Query was cancelled", TimeoutException(), + Connector.UserCancellationToken)); + } + + throw Connector.Break(NpgsqlTimeoutException()); } default: @@ -250,8 +260,9 @@ async Task EnsureLong() Cts.Stop(); NpgsqlEventSource.Log.BytesRead(totalRead); - static Exception TimeoutException() - => new NpgsqlException("Exception while reading from stream", new TimeoutException("Timeout during reading attempt")); + static Exception NpgsqlTimeoutException() => new NpgsqlException("Exception while reading from stream", TimeoutException()); + + static Exception TimeoutException() => new TimeoutException("Timeout during reading attempt"); } } diff --git a/test/Npgsql.Tests/ReaderTests.cs b/test/Npgsql.Tests/ReaderTests.cs index 403fe183cb..a37e798e94 100644 --- a/test/Npgsql.Tests/ReaderTests.cs +++ b/test/Npgsql.Tests/ReaderTests.cs @@ -1075,7 +1075,6 @@ public async Task ManyReads() } } - [Test] public async Task NullableScalar() { @@ -1681,7 +1680,7 @@ await pgMock Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); } - [Test, Description("Cancels ReadAsync via the cancellation token, with unsuccessful PG cancellation (socket break)")] + [Test, Description("Cancels NextResultAsync via the cancellation token, with unsuccessful PG cancellation (socket break)")] public async Task NextResultAsync_cancel_hard() { if (IsMultiplexing) @@ -1724,8 +1723,205 @@ await pgMock Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); } + [Test, Description("Cancels sequential NextResultAsync via the cancellation token")] + public async Task NextResultAsync_sequential_cancel() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + if (!IsSequential) + return; + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); + await using var conn = await OpenConnectionAsync(connectionString); + + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea)) + .WriteDataRowWithFlush(new byte[10000]); + + using var cmd = new NpgsqlCommand("SELECT some_bytea FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + + // Successfully read the first resultset + Assert.True(await reader.ReadAsync()); + + // Attempt to read the second row - simulate blocking and cancellation + var cancellationSource = new CancellationTokenSource(); + var task = reader.NextResultAsync(cancellationSource.Token); + cancellationSource.Cancel(); + + var (processId, _) = await postmasterMock.WaitForCancellationRequest(); + Assert.That(processId, Is.EqualTo(conn.ProcessID)); + + // Send no response from server, wait for the cancellation attempt to time out + var exception = Assert.ThrowsAsync(async () => await task); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } + + [Test, Description("Cancels sequential ReadAsGetFieldValueAsync")] + public async Task GetFieldValueAsync_sequential_cancel() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + if (!IsSequential) + return; + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); + await using var conn = await OpenConnectionAsync(connectionString); + + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea)) + .WriteDataRowWithFlush(new byte[10000]); + + using var cmd = new NpgsqlCommand("SELECT some_bytea FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + + await reader.ReadAsync(); + + using var cts = new CancellationTokenSource(); + var task = reader.GetFieldValueAsync(0, cts.Token); + cts.Cancel(); + + var exception = Assert.ThrowsAsync(async () => await task); + Assert.That(exception.InnerException, Is.Null); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } + + [Test, Description("Cancels sequential ReadAsGetFieldValueAsync")] + public async Task IsDBNullAsync_sequential_cancel() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + if (!IsSequential) + return; + + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); + await using var conn = await OpenConnectionAsync(connectionString); + + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea), new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteDataRowWithFlush(new byte[10000], new byte[4]); + + using var cmd = new NpgsqlCommand("SELECT some_bytea, some_int FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + + await reader.ReadAsync(); + + using var cts = new CancellationTokenSource(); + var task = reader.IsDBNullAsync(1, cts.Token); + cts.Cancel(); + + var exception = Assert.ThrowsAsync(async () => await task); + Assert.That(exception.InnerException, Is.Null); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } + #endregion Cancellation + #region Timeout + + [Test, Description("Timeouts sequential ReadAsGetFieldValueAsync")] + [Timeout(10000)] + public async Task GetFieldValueAsync_sequential_timeout() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + if (!IsSequential) + return; + + var csb = new NpgsqlConnectionStringBuilder(ConnectionString); + csb.CommandTimeout = 3; + csb.CancellationTimeout = 15000; + + await using var postmasterMock = PgPostmasterMock.Start(csb.ToString()); + using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); + await using var conn = await OpenConnectionAsync(connectionString); + + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea)) + .WriteDataRowWithFlush(new byte[10000]); + + using var cmd = new NpgsqlCommand("SELECT some_bytea FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + + await reader.ReadAsync(); + + var task = reader.GetFieldValueAsync(0); + + var exception = Assert.ThrowsAsync(async () => await task); + Assert.That(exception.InnerException, Is.TypeOf()); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } + + [Test, Description("Timeouts sequential IsDBNullAsync")] + [Timeout(10000)] + public async Task IsDBNullAsync_sequential_timeout() + { + if (IsMultiplexing) + return; // Multiplexing, cancellation + + if (!IsSequential) + return; + + var csb = new NpgsqlConnectionStringBuilder(ConnectionString); + csb.CommandTimeout = 3; + csb.CancellationTimeout = 15000; + + await using var postmasterMock = PgPostmasterMock.Start(csb.ToString()); + using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); + await using var conn = await OpenConnectionAsync(connectionString); + + // Write responses to the query we're about to send, with a single data row (we'll attempt to read two) + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Bytea), new FieldDescription(PostgresTypeOIDs.Int4)) + .WriteDataRowWithFlush(new byte[10000], new byte[4]); + + using var cmd = new NpgsqlCommand("SELECT some_bytea, some_int FROM some_table", conn); + await using var reader = await cmd.ExecuteReaderAsync(Behavior); + + await reader.ReadAsync(); + + var task = reader.GetFieldValueAsync(0); + + var exception = Assert.ThrowsAsync(async () => await task); + Assert.That(exception.InnerException, Is.TypeOf()); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); + } + + #endregion + #region Initialization / setup / teardown // ReSharper disable InconsistentNaming diff --git a/test/Npgsql.Tests/Support/PgServerMock.cs b/test/Npgsql.Tests/Support/PgServerMock.cs index 5f3f377e30..b30ce89bd7 100644 --- a/test/Npgsql.Tests/Support/PgServerMock.cs +++ b/test/Npgsql.Tests/Support/PgServerMock.cs @@ -158,6 +158,21 @@ internal PgServerMock WriteDataRow(params byte[][] columnValues) return this; } + internal async Task WriteDataRowWithFlush(params byte[][] columnValues) + { + CheckDisposed(); + + _writeBuffer.WriteByte((byte) BackendMessageCode.DataRow); + _writeBuffer.WriteInt32(4 + 2 + columnValues.Sum(v => 4 + v.Length)); + _writeBuffer.WriteInt16(columnValues.Length); + + foreach (var field in columnValues) + { + _writeBuffer.WriteInt32(field.Length); + await _writeBuffer.WriteBytesRaw(field, true); + } + } + internal PgServerMock WriteCommandComplete(string tag = "") { CheckDisposed();