From 7d0aef26d69dda5eb6695ee9e7bc32ea9d888ea1 Mon Sep 17 00:00:00 2001 From: Nino Floris Date: Fri, 16 Feb 2024 19:53:57 +0100 Subject: [PATCH 1/6] WriteBind remove casts and add some protocol limits --- .../NpgsqlConnector.FrontendMessages.cs | 106 ++++++++++-------- 1 file changed, 62 insertions(+), 44 deletions(-) diff --git a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs index 9e0fd45dd3..db1eaf38d2 100644 --- a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs +++ b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs @@ -165,81 +165,93 @@ internal async Task WriteBind( bool async, CancellationToken cancellationToken = default) { + const short textFormatCode = 0; + const short binaryFormatCode = 1; + + Debug.Assert(portal == string.Empty); + NpgsqlWriteBuffer.AssertASCIIOnly(asciiName); NpgsqlWriteBuffer.AssertASCIIOnly(portal); - var headerLength = - sizeof(byte) + // Message code - sizeof(int) + // Message length - sizeof(byte) + // Portal is always empty (only a null terminator) - asciiName.Length + sizeof(byte) + // Statement name plus null terminator - sizeof(ushort); // Number of parameter format codes that follow + if (parameters.Count > ushort.MaxValue) + ThrowHelper.ThrowArgumentException("Too many parameters in statement (max: 65535).", nameof(parameters)); + var parameterCount = (ushort)parameters.Count; + + // PG limit is 1664 (see https://www.postgresql.org/docs/current/limits.html) however we purely guard datatype limits here. + if (unknownResultTypeList?.Length > short.MaxValue) + ThrowHelper.ThrowArgumentException("Too many result types in statement (max: 32768).", nameof(unknownResultTypeList)); + var unknownResultTypeCount = (short)(unknownResultTypeList?.Length ?? 1); - var writeBuffer = WriteBuffer; var formatCodesSum = 0; - var paramsLength = 0; - for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++) + var parametersLength = 0; + for (var i = 0; i < parameters.Count; i++) { - var param = parameters[paramIndex]; + var param = parameters[i]; param.Bind(out var format, out var size); - paramsLength += size.Value > 0 ? size.Value : 0; + parametersLength += size.Value > 0 ? size.Value : 0; formatCodesSum += format.ToFormatCode(); } + var formatCodeListLength = formatCodesSum is 0 ? (ushort)0 : formatCodesSum == parameterCount ? (ushort)1 : parameterCount; - var formatCodeListLength = formatCodesSum == 0 ? 0 : formatCodesSum == parameters.Count ? 1 : parameters.Count; + var headerLength = + sizeof(byte) + // Message code + sizeof(int) + // Message length + sizeof(byte) + // Portal (only a null terminator) + asciiName.Length + sizeof(byte) + // Statement name plus null terminator + sizeof(ushort); // Number of parameter format codes that follow - var messageLength = headerLength + - sizeof(short) * formatCodeListLength + // List of format codes - sizeof(short) + // Number of parameters - sizeof(int) * parameters.Count + // Parameter lengths - paramsLength + // Parameter values - sizeof(short) + // Number of result format codes - sizeof(short) * (unknownResultTypeList?.Length ?? 1); // Result format codes + var messageLength = + headerLength + + sizeof(short) * formatCodeListLength + // List of format codes + sizeof(short) + // Number of parameters + sizeof(int) * parameterCount + // Parameter lengths + parametersLength + // Parameter values + sizeof(short) + // Number of result format codes + sizeof(short) * unknownResultTypeCount; // Result format codes - WriteBuffer.StartMessage(messageLength); - if (WriteBuffer.WriteSpaceLeft < headerLength) + var writeBuffer = WriteBuffer; + writeBuffer.StartMessage(messageLength); + if (writeBuffer.WriteSpaceLeft < headerLength) { - Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header"); + Debug.Assert(writeBuffer.Size >= headerLength, "Write buffer too small for Bind header"); await Flush(async, cancellationToken).ConfigureAwait(false); } - WriteBuffer.WriteByte(FrontendMessageCode.Bind); - WriteBuffer.WriteInt32(messageLength - 1); - Debug.Assert(portal == string.Empty); - writeBuffer.WriteByte(0); // Portal is always empty - + writeBuffer.WriteByte(FrontendMessageCode.Bind); + writeBuffer.WriteInt32(messageLength - 1); + writeBuffer.WriteByte(0); // Portal is always empty writeBuffer.WriteNullTerminatedString(asciiName); - writeBuffer.WriteInt16((short)formatCodeListLength); - // 0 length implicitly means all-text, 1 means all-binary, >1 means mix-and-match - if (formatCodeListLength == 1) + // 0 length implicitly means all-text, 1 is a uniform format, >1 is per parameter formats + writeBuffer.WriteUInt16(formatCodeListLength); + if (formatCodeListLength is 1) { if (writeBuffer.WriteSpaceLeft < sizeof(short)) await Flush(async, cancellationToken).ConfigureAwait(false); - writeBuffer.WriteInt16(DataFormat.Binary.ToFormatCode()); + writeBuffer.WriteInt16(binaryFormatCode); } else if (formatCodeListLength > 1) { - for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++) + for (var i = 0; i < parameters.Count; i++) { if (writeBuffer.WriteSpaceLeft < sizeof(short)) await Flush(async, cancellationToken).ConfigureAwait(false); - writeBuffer.WriteInt16(parameters[paramIndex].Format.ToFormatCode()); + writeBuffer.WriteInt16(parameters[i].Format.ToFormatCode()); } } if (writeBuffer.WriteSpaceLeft < sizeof(ushort)) await Flush(async, cancellationToken).ConfigureAwait(false); - writeBuffer.WriteUInt16((ushort)parameters.Count); - if (parameters.Count > 0) + writeBuffer.WriteUInt16(parameterCount); + if (parameterCount > 0) { var writer = writeBuffer.GetWriter(DatabaseInfo, async ? FlushMode.NonBlocking : FlushMode.Blocking); try { - for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++) + for (var i = 0; i < parameters.Count; i++) { - var param = parameters[paramIndex]; + var param = parameters[i]; await param.Write(async, writer, cancellationToken).ConfigureAwait(false); } } @@ -252,18 +264,24 @@ internal async Task WriteBind( if (unknownResultTypeList != null) { - if (writeBuffer.WriteSpaceLeft < 2 + unknownResultTypeList.Length * 2) + if (writeBuffer.WriteSpaceLeft < sizeof(short) + unknownResultTypeCount * sizeof(short)) await Flush(async, cancellationToken).ConfigureAwait(false); - writeBuffer.WriteInt16((short)unknownResultTypeList.Length); - foreach (var t in unknownResultTypeList) - writeBuffer.WriteInt16((short)(t ? 0 : 1)); + + writeBuffer.WriteInt16(unknownResultTypeCount); + for (var i = 0; i < unknownResultTypeList.Length; i++) + { + var unknownResultType = unknownResultTypeList[i]; + writeBuffer.WriteInt16(unknownResultType ? textFormatCode : binaryFormatCode); + } } else { - if (writeBuffer.WriteSpaceLeft < 4) + if (writeBuffer.WriteSpaceLeft < sizeof(short) + sizeof(short)) await Flush(async, cancellationToken).ConfigureAwait(false); - writeBuffer.WriteInt16(1); - writeBuffer.WriteInt16((short)(allResultTypesAreUnknown ? 0 : 1)); + + Debug.Assert(unknownResultTypeCount is 1); + writeBuffer.WriteInt16(unknownResultTypeCount); + writeBuffer.WriteInt16(allResultTypesAreUnknown ? textFormatCode : binaryFormatCode); } } From 375f2aeab6c56a20d6de61c2d642fd126dc6159e Mon Sep 17 00:00:00 2001 From: Nino Floris Date: Fri, 16 Feb 2024 20:26:15 +0100 Subject: [PATCH 2/6] WriteParse remove casts and add some protocol limits --- .../NpgsqlConnector.FrontendMessages.cs | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs index db1eaf38d2..53d457c54d 100644 --- a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs +++ b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs @@ -105,10 +105,15 @@ static void Write(NpgsqlWriteBuffer writeBuffer, int maxRows) } } - internal async Task WriteParse(string sql, byte[] asciiName, List inputParameters, bool async, CancellationToken cancellationToken = default) + internal async Task WriteParse(string sql, byte[] asciiName, List parameters, bool async, CancellationToken cancellationToken = default) { NpgsqlWriteBuffer.AssertASCIIOnly(asciiName); + // See https://www.postgresql.org/docs/current/limits.html + if (parameters.Count > ushort.MaxValue) + ThrowHelper.ThrowArgumentException("Too many parameters in statement (max: 65535).", nameof(parameters)); + var parameterCount = (ushort)parameters.Count; + int queryByteLen; try { @@ -120,39 +125,41 @@ internal async Task WriteParse(string sql, byte[] asciiName, List= headerLength, "Write buffer too small for Parse header"); await Flush(async, cancellationToken).ConfigureAwait(false); + } - WriteBuffer.WriteByte(FrontendMessageCode.Parse); - WriteBuffer.WriteInt32(messageLength - 1); - WriteBuffer.WriteNullTerminatedString(asciiName); + writeBuffer.WriteByte(FrontendMessageCode.Parse); + writeBuffer.WriteInt32(messageLength - 1); + writeBuffer.WriteNullTerminatedString(asciiName); await writeBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false); - - if (writeBuffer.WriteSpaceLeft < 1 + 2) + if (writeBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(short)) await Flush(async, cancellationToken).ConfigureAwait(false); writeBuffer.WriteByte(0); // Null terminator for the query - writeBuffer.WriteUInt16((ushort)inputParameters.Count); + writeBuffer.WriteUInt16(parameterCount); var databaseInfo = DatabaseInfo; - foreach (var p in inputParameters) + for (var i = 0; i < parameters.Count; i++) { - if (writeBuffer.WriteSpaceLeft < 4) + if (writeBuffer.WriteSpaceLeft < sizeof(uint)) await Flush(async, cancellationToken).ConfigureAwait(false); - - writeBuffer.WriteUInt32(databaseInfo.GetOid(p.PgTypeId).Value); + writeBuffer.WriteUInt32(databaseInfo.GetOid(parameters[i].PgTypeId).Value); } } From 6217bf8f5f6ea05f504539f197b7a5c11aab90d0 Mon Sep 17 00:00:00 2001 From: Nino Floris Date: Mon, 25 Mar 2024 14:07:50 +0100 Subject: [PATCH 3/6] WriteQuery remove casts and add some protocol limits --- .../NpgsqlConnector.FrontendMessages.cs | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs index 53d457c54d..708a9f11a4 100644 --- a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs +++ b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs @@ -326,22 +326,35 @@ static void Write(NpgsqlWriteBuffer writeBuffer, int len, StatementOrPortal type internal async Task WriteQuery(string sql, bool async, CancellationToken cancellationToken = default) { - var queryByteLen = TextEncoding.GetByteCount(sql); + int queryByteLen; + try + { + queryByteLen = TextEncoding.GetByteCount(sql); + } + catch (Exception e) + { + Break(e); + throw; + } - var len = sizeof(byte) + - sizeof(int) + // Message length (including self excluding code) - queryByteLen + // Query byte length - sizeof(byte); + var len = + sizeof(byte) + // Message code + sizeof(int) + // Message length + queryByteLen + // Query byte length + sizeof(byte); // Null terminator WriteBuffer.StartMessage(len); - if (WriteBuffer.WriteSpaceLeft < 1 + 4) + if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int)) + { + Debug.Assert(WriteBuffer.Size >= sizeof(byte) + sizeof(int), "Write buffer too small for Parse header"); await Flush(async, cancellationToken).ConfigureAwait(false); + } WriteBuffer.WriteByte(FrontendMessageCode.Query); WriteBuffer.WriteInt32(len - 1); await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false); - if (WriteBuffer.WriteSpaceLeft < 1) + if (WriteBuffer.WriteSpaceLeft < sizeof(byte)) await Flush(async, cancellationToken).ConfigureAwait(false); WriteBuffer.WriteByte(0); // Null terminator } From 1b1027d685bcf68d7cea15823e4b7e2ba37275fe Mon Sep 17 00:00:00 2001 From: Nino Floris Date: Thu, 30 May 2024 18:05:15 +0200 Subject: [PATCH 4/6] Simplify WritePassword --- .../Internal/NpgsqlConnector.FrontendMessages.cs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs index 708a9f11a4..c171b6b569 100644 --- a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs +++ b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs @@ -477,16 +477,7 @@ internal async Task WritePassword(byte[] payload, int offset, int count, bool as WriteBuffer.WriteByte(FrontendMessageCode.Password); WriteBuffer.WriteInt32(sizeof(int) + count); - - if (count <= WriteBuffer.WriteSpaceLeft) - { - // The entire array fits in our WriteBuffer, copy it into the WriteBuffer as usual. - WriteBuffer.WriteBytes(payload, offset, count); - return; - } - - await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false); - await WriteBuffer.DirectWrite(new ReadOnlyMemory(payload, offset, count), async, cancellationToken).ConfigureAwait(false); + await WriteBuffer.WriteBytesRaw(new ReadOnlyMemory(payload, offset, count), async, cancellationToken).ConfigureAwait(false); } internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialResponse, bool async, CancellationToken cancellationToken = default) From 52bb05954c2a105d08c7357b1a62c94ebbea2d52 Mon Sep 17 00:00:00 2001 From: Nino Floris Date: Thu, 30 May 2024 18:05:32 +0200 Subject: [PATCH 5/6] Remove unused WriteStreamRaw --- src/Npgsql/Internal/NpgsqlWriteBuffer.cs | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/src/Npgsql/Internal/NpgsqlWriteBuffer.cs b/src/Npgsql/Internal/NpgsqlWriteBuffer.cs index 821bb7e6b1..eeee0a9b60 100644 --- a/src/Npgsql/Internal/NpgsqlWriteBuffer.cs +++ b/src/Npgsql/Internal/NpgsqlWriteBuffer.cs @@ -422,30 +422,6 @@ static async Task WriteBytesLong(NpgsqlWriteBuffer buffer, bool async, ReadOnlyM } } - public async Task WriteStreamRaw(Stream stream, int count, bool async, CancellationToken cancellationToken = default) - { - while (count > 0) - { - if (WriteSpaceLeft == 0) - await Flush(async, cancellationToken).ConfigureAwait(false); - try - { - var read = async - ? await stream.ReadAsync(Buffer, WritePosition, Math.Min(WriteSpaceLeft, count), cancellationToken).ConfigureAwait(false) - : stream.Read(Buffer, WritePosition, Math.Min(WriteSpaceLeft, count)); - if (read == 0) - throw new EndOfStreamException(); - WritePosition += read; - count -= read; - } - catch (Exception e) - { - throw Connector.Break(new NpgsqlException("Exception while writing to stream", e)); - } - } - Debug.Assert(count == 0); - } - public void WriteNullTerminatedString(string s) { AssertASCIIOnly(s); From 8f3382e6a9a9ad909a35991183f361448da6bf7e Mon Sep 17 00:00:00 2001 From: Nino Floris Date: Mon, 25 Mar 2024 14:51:58 +0100 Subject: [PATCH 6/6] Break connector on message write failures --- .../NpgsqlConnector.FrontendMessages.cs | 37 +-- src/Npgsql/NpgsqlCommand.cs | 257 ++++++++++-------- 2 files changed, 155 insertions(+), 139 deletions(-) diff --git a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs index c171b6b569..b721fbbbd0 100644 --- a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs +++ b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs @@ -114,16 +114,7 @@ internal async Task WriteParse(string sql, byte[] asciiName, List 0) { var writer = writeBuffer.GetWriter(DatabaseInfo, async ? FlushMode.NonBlocking : FlushMode.Blocking); - try - { - for (var i = 0; i < parameters.Count; i++) - { - var param = parameters[i]; - await param.Write(async, writer, cancellationToken).ConfigureAwait(false); - } - } - catch(Exception ex) - { - Break(ex); - throw; - } + for (var i = 0; i < parameters.Count; i++) + await parameters[i].Write(async, writer, cancellationToken).ConfigureAwait(false); } if (unknownResultTypeList != null) @@ -326,16 +306,7 @@ static void Write(NpgsqlWriteBuffer writeBuffer, int len, StatementOrPortal type internal async Task WriteQuery(string sql, bool async, CancellationToken cancellationToken = default) { - int queryByteLen; - try - { - queryByteLen = TextEncoding.GetByteCount(sql); - } - catch (Exception e) - { - Break(e); - throw; - } + var queryByteLen = TextEncoding.GetByteCount(sql); var len = sizeof(byte) + // Message code diff --git a/src/Npgsql/NpgsqlCommand.cs b/src/Npgsql/NpgsqlCommand.cs index 0a46e675cf..88aaef0561 100644 --- a/src/Npgsql/NpgsqlCommand.cs +++ b/src/Npgsql/NpgsqlCommand.cs @@ -1043,147 +1043,183 @@ internal Task Write(NpgsqlConnector connector, bool async, bool flush, Cancellat async Task WriteExecute(NpgsqlConnector connector, bool async, bool flush, CancellationToken cancellationToken) { - NpgsqlBatchCommand? batchCommand = null; - - var syncCaller = !async; - for (var i = 0; i < InternalBatchCommands.Count; i++) + try { - // The following is only for deadlock avoidance when doing sync I/O (so never in multiplexing) - if (syncCaller && ShouldSchedule(ref async, i)) - await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); + NpgsqlBatchCommand? batchCommand = null; - batchCommand = InternalBatchCommands[i]; - var pStatement = batchCommand.PreparedStatement; + var syncCaller = !async; + for (var i = 0; i < InternalBatchCommands.Count; i++) + { + // The following is only for deadlock avoidance when doing sync I/O (so never in multiplexing) + if (syncCaller && ShouldSchedule(ref async, i)) + await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); - Debug.Assert(batchCommand.FinalCommandText is not null); + batchCommand = InternalBatchCommands[i]; + var pStatement = batchCommand.PreparedStatement; - if (pStatement == null || batchCommand.IsPreparing) - { - // The statement should either execute unprepared, or is being auto-prepared. - // Send Parse, Bind, Describe + Debug.Assert(batchCommand.FinalCommandText is not null); - // We may have a prepared statement that replaces an existing statement - close the latter first. - if (pStatement?.StatementBeingReplaced != null) - await connector.WriteClose(StatementOrPortal.Statement, pStatement.StatementBeingReplaced.Name!, async, cancellationToken).ConfigureAwait(false); + if (pStatement == null || batchCommand.IsPreparing) + { + // The statement should either execute unprepared, or is being auto-prepared. + // Send Parse, Bind, Describe - await connector.WriteParse(batchCommand.FinalCommandText, batchCommand.StatementName, - batchCommand.CurrentParametersReadOnly, async, cancellationToken).ConfigureAwait(false); + // We may have a prepared statement that replaces an existing statement - close the latter first. + if (pStatement?.StatementBeingReplaced != null) + await connector.WriteClose(StatementOrPortal.Statement, pStatement.StatementBeingReplaced.Name!, async, + cancellationToken).ConfigureAwait(false); - await connector.WriteBind( - batchCommand.CurrentParametersReadOnly, - string.Empty, batchCommand.StatementName, AllResultTypesAreUnknown, - i == 0 ? UnknownResultTypeList : null, - async, cancellationToken).ConfigureAwait(false); + await connector.WriteParse(batchCommand.FinalCommandText, batchCommand.StatementName, + batchCommand.CurrentParametersReadOnly, async, cancellationToken).ConfigureAwait(false); - await connector.WriteDescribe(StatementOrPortal.Portal, Array.Empty(), async, cancellationToken).ConfigureAwait(false); - } - else - { - // The statement is already prepared, only a Bind is needed - await connector.WriteBind( - batchCommand.CurrentParametersReadOnly, - string.Empty, batchCommand.StatementName, AllResultTypesAreUnknown, - i == 0 ? UnknownResultTypeList : null, - async, cancellationToken).ConfigureAwait(false); - } + await connector.WriteBind( + batchCommand.CurrentParametersReadOnly, + string.Empty, batchCommand.StatementName, AllResultTypesAreUnknown, + i == 0 ? UnknownResultTypeList : null, + async, cancellationToken).ConfigureAwait(false); + + await connector.WriteDescribe(StatementOrPortal.Portal, Array.Empty(), async, cancellationToken) + .ConfigureAwait(false); + } + else + { + // The statement is already prepared, only a Bind is needed + await connector.WriteBind( + batchCommand.CurrentParametersReadOnly, + string.Empty, batchCommand.StatementName, AllResultTypesAreUnknown, + i == 0 ? UnknownResultTypeList : null, + async, cancellationToken).ConfigureAwait(false); + } + + await connector.WriteExecute(0, async, cancellationToken).ConfigureAwait(false); - await connector.WriteExecute(0, async, cancellationToken).ConfigureAwait(false); + if (batchCommand.AppendErrorBarrier ?? EnableErrorBarriers) + await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); - if (batchCommand.AppendErrorBarrier ?? EnableErrorBarriers) + pStatement?.RefreshLastUsed(); + } + + if (batchCommand is null || !(batchCommand.AppendErrorBarrier ?? EnableErrorBarriers)) + { await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); + } - pStatement?.RefreshLastUsed(); + if (flush) + await connector.Flush(async, cancellationToken).ConfigureAwait(false); } - - if (batchCommand is null || !(batchCommand.AppendErrorBarrier ?? EnableErrorBarriers)) + catch(Exception ex) { - await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); + connector.Break(ex); + throw; } - - if (flush) - await connector.Flush(async, cancellationToken).ConfigureAwait(false); } async Task WriteExecuteSchemaOnly(NpgsqlConnector connector, bool async, bool flush, CancellationToken cancellationToken) { - var wroteSomething = false; - var syncCaller = !async; - for (var i = 0; i < InternalBatchCommands.Count; i++) + try { - if (syncCaller && ShouldSchedule(ref async, i)) - await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); + var wroteSomething = false; + var syncCaller = !async; + for (var i = 0; i < InternalBatchCommands.Count; i++) + { + if (syncCaller && ShouldSchedule(ref async, i)) + await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); - var batchCommand = InternalBatchCommands[i]; + var batchCommand = InternalBatchCommands[i]; - if (batchCommand.PreparedStatement?.State == PreparedState.Prepared) - continue; // Prepared, we already have the RowDescription + if (batchCommand.PreparedStatement?.State == PreparedState.Prepared) + continue; // Prepared, we already have the RowDescription - await connector.WriteParse(batchCommand.FinalCommandText!, batchCommand.StatementName, - batchCommand.CurrentParametersReadOnly, - async, cancellationToken).ConfigureAwait(false); - await connector.WriteDescribe(StatementOrPortal.Statement, batchCommand.StatementName, async, cancellationToken).ConfigureAwait(false); - wroteSomething = true; - } + await connector.WriteParse(batchCommand.FinalCommandText!, batchCommand.StatementName, + batchCommand.CurrentParametersReadOnly, + async, cancellationToken).ConfigureAwait(false); + await connector.WriteDescribe(StatementOrPortal.Statement, batchCommand.StatementName, async, cancellationToken).ConfigureAwait(false); + wroteSomething = true; + } - if (wroteSomething) + if (wroteSomething) + { + await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); + if (flush) + await connector.Flush(async, cancellationToken).ConfigureAwait(false); + } + } + catch(Exception ex) { - await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); - if (flush) - await connector.Flush(async, cancellationToken).ConfigureAwait(false); + connector.Break(ex); + throw; } } } async Task SendDeriveParameters(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) { - BeginSend(connector); - - var syncCaller = !async; - for (var i = 0; i < InternalBatchCommands.Count; i++) + try { - if (syncCaller && ShouldSchedule(ref async, i)) - await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); + BeginSend(connector); - var batchCommand = InternalBatchCommands[i]; + var syncCaller = !async; + for (var i = 0; i < InternalBatchCommands.Count; i++) + { + if (syncCaller && ShouldSchedule(ref async, i)) + await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); - await connector.WriteParse(batchCommand.FinalCommandText!, Array.Empty(), NpgsqlBatchCommand.EmptyParameters, async, cancellationToken).ConfigureAwait(false); - await connector.WriteDescribe(StatementOrPortal.Statement, Array.Empty(), async, cancellationToken).ConfigureAwait(false); - } + var batchCommand = InternalBatchCommands[i]; - await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); - await connector.Flush(async, cancellationToken).ConfigureAwait(false); + await connector.WriteParse(batchCommand.FinalCommandText!, Array.Empty(), NpgsqlBatchCommand.EmptyParameters, async, cancellationToken).ConfigureAwait(false); + await connector.WriteDescribe(StatementOrPortal.Statement, Array.Empty(), async, cancellationToken).ConfigureAwait(false); + } + + await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); + await connector.Flush(async, cancellationToken).ConfigureAwait(false); + } + catch(Exception ex) + { + connector.Break(ex); + throw; + } } async Task SendPrepare(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) { - BeginSend(connector); - - var syncCaller = !async; - for (var i = 0; i < InternalBatchCommands.Count; i++) + try { - if (syncCaller && ShouldSchedule(ref async, i)) - await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); + BeginSend(connector); - var batchCommand = InternalBatchCommands[i]; - var pStatement = batchCommand.PreparedStatement; + var syncCaller = !async; + for (var i = 0; i < InternalBatchCommands.Count; i++) + { + if (syncCaller && ShouldSchedule(ref async, i)) + await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); - // A statement may be already prepared, already in preparation (i.e. same statement twice - // in the same command), or we can't prepare (overloaded SQL) - if (!batchCommand.IsPreparing) - continue; + var batchCommand = InternalBatchCommands[i]; + var pStatement = batchCommand.PreparedStatement; - // We may have a prepared statement that replaces an existing statement - close the latter first. - var statementToClose = pStatement!.StatementBeingReplaced; - if (statementToClose != null) - await connector.WriteClose(StatementOrPortal.Statement, statementToClose.Name!, async, cancellationToken).ConfigureAwait(false); + // A statement may be already prepared, already in preparation (i.e. same statement twice + // in the same command), or we can't prepare (overloaded SQL) + if (!batchCommand.IsPreparing) + continue; + + // We may have a prepared statement that replaces an existing statement - close the latter first. + var statementToClose = pStatement!.StatementBeingReplaced; + if (statementToClose != null) + await connector.WriteClose(StatementOrPortal.Statement, statementToClose.Name!, async, cancellationToken) + .ConfigureAwait(false); + + await connector.WriteParse(batchCommand.FinalCommandText!, pStatement.Name!, batchCommand.CurrentParametersReadOnly, async, + cancellationToken).ConfigureAwait(false); + await connector.WriteDescribe(StatementOrPortal.Statement, pStatement.Name!, async, cancellationToken) + .ConfigureAwait(false); + } - await connector.WriteParse(batchCommand.FinalCommandText!, pStatement.Name!, batchCommand.CurrentParametersReadOnly, async, - cancellationToken).ConfigureAwait(false); - await connector.WriteDescribe(StatementOrPortal.Statement, pStatement.Name!, async, cancellationToken).ConfigureAwait(false); + await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); + await connector.Flush(async, cancellationToken).ConfigureAwait(false); + } + catch(Exception ex) + { + connector.Break(ex); + throw; } - - await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); - await connector.Flush(async, cancellationToken).ConfigureAwait(false); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -1205,19 +1241,28 @@ bool ShouldSchedule(ref bool async, int indexOfStatementInBatch) async Task SendClose(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) { - BeginSend(connector); + try + { + BeginSend(connector); - foreach (var batchCommand in InternalBatchCommands) + foreach (var batchCommand in InternalBatchCommands) + { + if (!batchCommand.IsPrepared) + continue; + // No need to force async here since each statement takes no more than 20 bytes + await connector.WriteClose(StatementOrPortal.Statement, batchCommand.StatementName, async, cancellationToken) + .ConfigureAwait(false); + batchCommand.PreparedStatement!.State = PreparedState.BeingUnprepared; + } + + await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); + await connector.Flush(async, cancellationToken).ConfigureAwait(false); + } + catch(Exception ex) { - if (!batchCommand.IsPrepared) - continue; - // No need to force async here since each statement takes no more than 20 bytes - await connector.WriteClose(StatementOrPortal.Statement, batchCommand.StatementName, async, cancellationToken).ConfigureAwait(false); - batchCommand.PreparedStatement!.State = PreparedState.BeingUnprepared; + connector.Break(ex); + throw; } - - await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); - await connector.Flush(async, cancellationToken).ConfigureAwait(false); } #endregion