From 6d010535612f0e4917b5a2765bdb8390a81c5af4 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Fri, 23 Oct 2020 13:47:45 +0300 Subject: [PATCH] Prepend and quote savepoints Closes #2892 Closes #3262 --- src/Npgsql/NpgsqlConnection.cs | 14 ++-- src/Npgsql/NpgsqlConnector.cs | 14 ++-- src/Npgsql/NpgsqlTransaction.cs | 91 +++++++++++++++++------ test/Npgsql.Tests/ReaderTests.cs | 4 +- test/Npgsql.Tests/Support/PgServerMock.cs | 30 +++++++- test/Npgsql.Tests/TransactionTests.cs | 36 ++++++++- 6 files changed, 144 insertions(+), 45 deletions(-) diff --git a/src/Npgsql/NpgsqlConnection.cs b/src/Npgsql/NpgsqlConnection.cs index 551b27a689..71b112d903 100644 --- a/src/Npgsql/NpgsqlConnection.cs +++ b/src/Npgsql/NpgsqlConnection.cs @@ -575,14 +575,12 @@ async ValueTask BeginTransaction(IsolationLevel level, bool a try { - // Note that beginning a transaction doesn't actually send anything to the backend - // (only prepends), so strictly speaking we don't have to start a user action. - // However, we do this for consistency as if we did (for the checks and exceptions) - using (connector.StartUserAction()) - { - connector.Transaction.Init(level); - return connector.Transaction; - } + // Note that beginning a transaction doesn't actually send anything to the backend (only prepends), so strictly speaking we + // don't have to start a user action. However, we do this for consistency as if we did (for the checks and exceptions) + using var _ = connector.StartUserAction(); + + connector.Transaction.Init(level); + return connector.Transaction; } catch { diff --git a/src/Npgsql/NpgsqlConnector.cs b/src/Npgsql/NpgsqlConnector.cs index dd021f69d1..b71241ac7e 100644 --- a/src/Npgsql/NpgsqlConnector.cs +++ b/src/Npgsql/NpgsqlConnector.cs @@ -120,7 +120,7 @@ sealed partial class NpgsqlConnector : IDisposable /// The number of messages that were prepended to the current message chain, but not yet sent. /// Note that this only tracks messages which produce a ReadyForQuery message /// - int _pendingPrependedResponses; + internal int PendingPrependedResponses { get; set; } internal NpgsqlDataReader? CurrentReader; @@ -1022,7 +1022,7 @@ async Task MultiplexingReadLoop() /// internal void PrependInternalMessage(byte[] rawMessage, int responseMessageCount) { - _pendingPrependedResponses += responseMessageCount; + PendingPrependedResponses += responseMessageCount; var t = WritePregenerated(rawMessage); Debug.Assert(t.IsCompleted, "Could not fully write pregenerated message into the buffer"); @@ -1061,7 +1061,7 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d bool attemptPostgresCancellation = true, CancellationToken cancellationToken = default) { - if (_pendingPrependedResponses > 0 || + if (PendingPrependedResponses > 0 || dataRowLoadingMode != DataRowLoadingMode.NonSequential || readingNotifications || ReadBuffer.ReadBytesLeft < 5) @@ -1103,13 +1103,13 @@ internal ValueTask ReadMessage(bool async, DataRowLoadingMode d CancellationToken cancellationToken2 = default) { // First read the responses of any prepended messages. - if (_pendingPrependedResponses > 0 && !isReadingPrependedMessage) + if (PendingPrependedResponses > 0 && !isReadingPrependedMessage) { try { // TODO: There could be room for optimization here, rather than the async call(s) ReadBuffer.Timeout = TimeSpan.FromMilliseconds(InternalCommandTimeout); - for (; _pendingPrependedResponses > 0; _pendingPrependedResponses--) + for (; PendingPrependedResponses > 0; PendingPrependedResponses--) await ReadMessageLong(DataRowLoadingMode.Skip, readingNotifications2: false, isReadingPrependedMessage: true, attemptPostgresCancellation2: false, cancellationToken2); } @@ -1795,7 +1795,7 @@ internal async Task Reset(bool async, CancellationToken cancellationToken = defa // Our buffer may contain unsent prepended messages (such as BeginTransaction), clear it out completely WriteBuffer.Clear(); - _pendingPrependedResponses = 0; + PendingPrependedResponses = 0; // We may have allocated an oversize read buffer, switch back to the original one // TODO: Replace this with array pooling, #2326 @@ -2139,7 +2139,7 @@ internal async Task ExecuteInternalCommand(byte[] data, bool async, Cancellation { Debug.Assert(State != ConnectorState.Ready, "Forgot to start a user action..."); - Log.Trace($"Executing internal pregenerated command", Id); + Log.Trace("Executing internal pregenerated command", Id); await WritePregenerated(data, async, cancellationToken); await Flush(async, cancellationToken); diff --git a/src/Npgsql/NpgsqlTransaction.cs b/src/Npgsql/NpgsqlTransaction.cs index 27fa2bbeca..e7661d93d9 100644 --- a/src/Npgsql/NpgsqlTransaction.cs +++ b/src/Npgsql/NpgsqlTransaction.cs @@ -2,6 +2,7 @@ using System.Data; using System.Data.Common; using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Npgsql.Logging; @@ -184,40 +185,63 @@ public override Task RollbackAsync(CancellationToken cancellationToken = default #region Savepoints - async Task Save(string name, bool async, CancellationToken cancellationToken = default) + /// + /// Creates a transaction save point. + /// + /// The name of the savepoint. + /// + /// This method does not cause a database roundtrip to be made. The savepoint creation statement will instead be sent along with + /// the next command. + /// +#if NET + public override void Save(string name) +#else + public void Save(string name) +#endif { if (name == null) throw new ArgumentNullException(nameof(name)); if (string.IsNullOrWhiteSpace(name)) throw new ArgumentException("name can't be empty", nameof(name)); - if (name.Contains(";")) - throw new ArgumentException("name can't contain a semicolon"); CheckReady(); if (!_connector.DatabaseInfo.SupportsTransactions) return; - using (_connector.StartUserAction()) - { - Log.Debug($"Creating savepoint {name}", _connector.Id); - await _connector.ExecuteInternalCommand($"SAVEPOINT {name}", async, cancellationToken); - } - } - /// - /// Creates a transaction save point. - /// - /// The name of the savepoint. -#if NET - public override void Save(string name) => Save(name, false).GetAwaiter().GetResult(); -#else - public void Save(string name) => Save(name, false).GetAwaiter().GetResult(); -#endif + // Note that creating a savepoint doesn't actually send anything to the backend (only prepends), so strictly speaking we don't + // have to start a user action. However, we do this for consistency as if we did (for the checks and exceptions) + using var _ = _connector.StartUserAction(); + + Log.Debug($"Creating savepoint {name}", _connector.Id); + + if (RequiresQuoting(name)) + name = $"\"{name.Replace("\"", "\"\"")}\""; + + // Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters. + // Since we are prepending, we assume below that the statement will always fit in the buffer. + _connector.WriteBuffer.WriteByte(FrontendMessageCode.Query); + _connector.WriteBuffer.WriteInt32( + sizeof(int) + // Message length (including self excluding code) + _connector.TextEncoding.GetByteCount("SAVEPOINT ") + + _connector.TextEncoding.GetByteCount(name) + + sizeof(byte)); // Null terminator + + _connector.WriteBuffer.WriteString("SAVEPOINT "); + _connector.WriteBuffer.WriteString(name); + _connector.WriteBuffer.WriteByte(0); + + _connector.PendingPrependedResponses += 2; + } /// /// Creates a transaction save point. /// /// The name of the savepoint. /// The token to monitor for cancellation requests. The default value is . + /// + /// This method does not cause a database roundtrip to be made, and will therefore always complete synchronously. + /// The savepoint creation statement will instead be sent along with the next command. + /// #if NET public override Task SaveAsync(string name, CancellationToken cancellationToken = default) #else @@ -226,8 +250,8 @@ public Task SaveAsync(string name, CancellationToken cancellationToken = default { if (cancellationToken.IsCancellationRequested) return Task.FromCanceled(cancellationToken); - using (NoSynchronizationContextScope.Enter()) - return Save(name, true, cancellationToken); + Save(name); + return Task.CompletedTask; } async Task Rollback(string name, bool async, CancellationToken cancellationToken = default) @@ -236,8 +260,6 @@ async Task Rollback(string name, bool async, CancellationToken cancellationToken throw new ArgumentNullException(nameof(name)); if (string.IsNullOrWhiteSpace(name)) throw new ArgumentException("name can't be empty", nameof(name)); - if (name.Contains(";")) - throw new ArgumentException("name can't contain a semicolon"); CheckReady(); if (!_connector.DatabaseInfo.SupportsTransactions) @@ -245,6 +267,10 @@ async Task Rollback(string name, bool async, CancellationToken cancellationToken using (_connector.StartUserAction()) { Log.Debug($"Rolling back savepoint {name}", _connector.Id); + + if (RequiresQuoting(name)) + name = $"\"{name.Replace("\"", "\"\"")}\""; + await _connector.ExecuteInternalCommand($"ROLLBACK TO SAVEPOINT {name}", async, cancellationToken); } } @@ -283,8 +309,6 @@ async Task Release(string name, bool async, CancellationToken cancellationToken throw new ArgumentNullException(nameof(name)); if (string.IsNullOrWhiteSpace(name)) throw new ArgumentException("name can't be empty", nameof(name)); - if (name.Contains(";")) - throw new ArgumentException("name can't contain a semicolon"); CheckReady(); if (!_connector.DatabaseInfo.SupportsTransactions) @@ -292,6 +316,10 @@ async Task Release(string name, bool async, CancellationToken cancellationToken using (_connector.StartUserAction()) { Log.Debug($"Releasing savepoint {name}", _connector.Id); + + if (RequiresQuoting(name)) + name = $"\"{name.Replace("\"", "\"\"")}\""; + await _connector.ExecuteInternalCommand($"RELEASE SAVEPOINT {name}", async, cancellationToken); } } @@ -401,6 +429,21 @@ void CheckReady() throw new InvalidOperationException("This NpgsqlTransaction has completed; it is no longer usable."); } + static bool RequiresQuoting(string identifier) + { + Debug.Assert(identifier.Length > 0); + + var first = identifier[0]; + if (first != '_' && !char.IsLower(first)) + return true; + + foreach (var c in identifier.AsSpan(1)) + if (c != '_' && c != '$' && !char.IsLower(c) && !char.IsDigit(c)) + return true; + + return false; + } + #endregion } } diff --git a/test/Npgsql.Tests/ReaderTests.cs b/test/Npgsql.Tests/ReaderTests.cs index a37e798e94..2d4421ab8f 100644 --- a/test/Npgsql.Tests/ReaderTests.cs +++ b/test/Npgsql.Tests/ReaderTests.cs @@ -1582,7 +1582,7 @@ await pgMock Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); } - await pgMock.WriteScalarResponse(1); + await pgMock.WriteScalarResponseAndFlush(1); Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); } @@ -1634,7 +1634,7 @@ await pgMock Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); } - await pgMock.WriteScalarResponse(1); + await pgMock.WriteScalarResponseAndFlush(1); Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); } diff --git a/test/Npgsql.Tests/Support/PgServerMock.cs b/test/Npgsql.Tests/Support/PgServerMock.cs index b30ce89bd7..4784a4a053 100644 --- a/test/Npgsql.Tests/Support/PgServerMock.cs +++ b/test/Npgsql.Tests/Support/PgServerMock.cs @@ -72,7 +72,7 @@ internal async Task SkipMessage() _readBuffer.Skip(len - 4); } - internal async Task ReadMessageType(byte expectedCode) + internal async Task ExpectMessage(byte expectedCode) { CheckDisposed(); @@ -84,13 +84,39 @@ internal async Task ReadMessageType(byte expectedCode) _readBuffer.Skip(len - 4); } + internal Task ExpectExtendedQuery() + => ExpectMessages( + FrontendMessageCode.Parse, + FrontendMessageCode.Bind, + FrontendMessageCode.Describe, + FrontendMessageCode.Execute, + FrontendMessageCode.Sync); + + internal async Task ExpectMessages(params byte[] expectedCodes) + { + foreach (var expectedCode in expectedCodes) + await ExpectMessage(expectedCode); + } + + internal async Task ExpectSimpleQuery(string expectedSql) + { + CheckDisposed(); + + await _readBuffer.EnsureAsync(5); + var actualCode = _readBuffer.ReadByte(); + Assert.That(actualCode, Is.EqualTo(FrontendMessageCode.Query), $"Expected message of type Query but got '{(char)actualCode}'"); + _ = _readBuffer.ReadInt32(); + var actualSql = _readBuffer.ReadNullTerminatedString(); + Assert.That(actualSql, Is.EqualTo(expectedSql)); + } + internal Task FlushAsync() { CheckDisposed(); return _writeBuffer.Flush(async: true); } - internal Task WriteScalarResponse(int value) + internal Task WriteScalarResponseAndFlush(int value) => WriteParseComplete() .WriteBindComplete() .WriteRowDescription(new FieldDescription(PostgresTypeOIDs.Int4)) diff --git a/test/Npgsql.Tests/TransactionTests.cs b/test/Npgsql.Tests/TransactionTests.cs index bfbbe8d150..4d3347c6b3 100644 --- a/test/Npgsql.Tests/TransactionTests.cs +++ b/test/Npgsql.Tests/TransactionTests.cs @@ -1,6 +1,9 @@ using System; +using System.Buffers.Binary; using System.Data; using System.Threading.Tasks; +using Npgsql.BackendMessages; +using Npgsql.Tests.Support; using Npgsql.Util; using NUnit.Framework; using static Npgsql.Tests.TestUtil; @@ -368,11 +371,40 @@ public async Task SavepointAsync() } [Test] - public async Task SavepointWithSemicolon() + public async Task SavepointQuoted() { await using var conn = await OpenConnectionAsync(); await using var tx = conn.BeginTransaction(); - Assert.That(() => tx.Save("a;b"), Throws.Exception.TypeOf()); + tx.Save("a;b"); + tx.Rollback("a;b"); + } + + [Test(Description = "Makes sure that creating a savepoint doesn't perform an additional roundtrip, but prepends to the next command")] + public async Task SavepointPrepends() + { + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + using var _ = CreateTempPool(postmasterMock.ConnectionString, out var connectionString); + await using var conn = await OpenConnectionAsync(connectionString); + var pgMock = await postmasterMock.WaitForServerConnection(); + + using var tx = conn.BeginTransaction(); + var saveTask = tx.SaveAsync("foo"); + Assert.That(saveTask.Status, Is.EqualTo(TaskStatus.RanToCompletion)); + + // If we're here, SaveAsync above didn't wait for any response, which is the right behavior + + await pgMock + .WriteCommandComplete() + .WriteReadyForQuery() // BEGIN response + .WriteCommandComplete() + .WriteReadyForQuery() // SAVEPOINT response + .WriteScalarResponseAndFlush(1); + + await conn.ExecuteScalarAsync("SELECT 1"); + + await pgMock.ExpectSimpleQuery("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED"); + await pgMock.ExpectSimpleQuery("SAVEPOINT foo"); + await pgMock.ExpectExtendedQuery(); } [Test, Description("Check IsCompleted before, during and after a normal committed transaction")]